Пример #1
0
class FUNITModel(nn.Module):
    """FUNIT few-shot image-to-image translation model.

    Bundles the training generator ``gen``, the multi-class patch
    discriminator ``dis`` and ``gen_test``, a deep copy of ``gen`` taken at
    construction that the inference helpers use. How ``gen_test`` is kept in
    sync with ``gen`` afterwards (e.g. a moving average) is not visible here
    -- TODO confirm in the trainer.
    """

    def __init__(self, hp):
        super(FUNITModel, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        self.gen_test = copy.deepcopy(self.gen)

    @staticmethod
    def _pool_class_codes(class_codes, k):
        """Average class codes over the batch in consecutive groups of ``k``.

        The batch dimension is moved last so ``avg_pool1d`` (whose stride
        defaults to the kernel size) averages each run of ``k`` codes; a
        trailing ``batch % k`` remainder is dropped by the pooling. A trailing
        singleton dimension is expected (removed by ``squeeze(-1)``, restored
        by ``unsqueeze(-1)``) -- presumably codes of shape (B, C, 1, 1), TODO
        confirm against ``enc_class_model``.
        """
        pooled = torch.nn.functional.avg_pool1d(
            class_codes.squeeze(-1).permute(1, 2, 0), k)
        return pooled.permute(2, 0, 1).unsqueeze(-1)

    @staticmethod
    def _recon_and_translate(gen, xa, xb):
        """Return ``(xr, xt)`` from ``gen``.

        ``xr`` is ``xa`` reconstructed; ``xt`` is ``xa``'s content decoded
        with ``xb``'s class code.
        """
        c_xa = gen.enc_content(xa)
        s_xa = gen.enc_class_model(xa)
        s_xb = gen.enc_class_model(xb)
        xt = gen.decode(c_xa, s_xb)  # translation
        xr = gen.decode(c_xa, s_xa)  # reconstruction
        return xr, xt

    def forward(self, co_data, cl_data, hp, mode):
        """Run one generator or discriminator update step, including backward.

        Args:
            co_data: content batch; ``co_data[0]`` images, ``co_data[1]`` labels.
            cl_data: class batch; ``cl_data[0]`` images, ``cl_data[1]`` labels.
            hp: loss weights; this method reads ``'gan_w'``, ``'r_w'`` and
                ``'fm_w'``.
            mode: ``'gen_update'`` or ``'dis_update'``.

        Returns:
            ``'gen_update'``: (l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc).
            ``'dis_update'``: (l_total, l_fake_p, l_real_pre, l_reg_pre, acc).

        Raises:
            ValueError: if ``mode`` is not one of the supported values.

        Gradients are accumulated via ``backward()``; no optimizer step is
        taken here.
        """
        # Validate before touching the GPU. An ``assert`` would be stripped
        # under ``python -O`` and the call would silently return None.
        if mode not in ('gen_update', 'dis_update'):
            raise ValueError('Not support operation')
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()

        if mode == 'gen_update':
            xr, xt = self._recon_and_translate(self.gen, xa, xb)
            l_adv_t, gacc_t, xt_gan_feat = self.dis.calc_gen_loss(xt, lb)
            l_adv_r, gacc_r, xr_gan_feat = self.dis.calc_gen_loss(xr, la)
            _, xb_gan_feat = self.dis(xb, lb)
            _, xa_gan_feat = self.dis(xa, la)
            # Feature matching on discriminator features averaged over dims
            # 3 and 2 (presumably the spatial dims).
            l_c_rec = recon_criterion(
                xr_gan_feat.mean(3).mean(2),
                xa_gan_feat.mean(3).mean(2))
            l_m_rec = recon_criterion(
                xt_gan_feat.mean(3).mean(2),
                xb_gan_feat.mean(3).mean(2))
            l_x_rec = recon_criterion(xr, xa)
            l_adv = 0.5 * (l_adv_t + l_adv_r)
            acc = 0.5 * (gacc_t + gacc_r)
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_x_rec + hp['fm_w'] *
                       (l_c_rec + l_m_rec))
            l_total.backward()
            return l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc

        # mode == 'dis_update'
        # calc_grad2(resp_r, xb) needs xb to be part of the autograd graph.
        xb.requires_grad_()
        l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb)
        l_real = hp['gan_w'] * l_real_pre
        # Keep the graph alive: resp_r is reused by calc_grad2 below.
        l_real.backward(retain_graph=True)
        # Real-image gradient penalty (Mescheder et al.), fixed weight 10.
        l_reg_pre = self.dis.calc_grad2(resp_r, xb)
        l_reg = 10 * l_reg_pre
        l_reg.backward()
        # Fake images are produced without tracking generator gradients.
        with torch.no_grad():
            c_xa = self.gen.enc_content(xa)
            s_xb = self.gen.enc_class_model(xb)
            xt = self.gen.decode(c_xa, s_xb)
        l_fake_p, acc_f, _ = self.dis.calc_dis_fake_loss(xt.detach(), lb)
        l_fake = hp['gan_w'] * l_fake_p
        l_fake.backward()
        l_total = l_fake + l_real + l_reg
        acc = 0.5 * (acc_f + acc_r)
        return l_total, l_fake_p, l_real_pre, l_reg_pre, acc

    def test(self, co_data, cl_data):
        """Reconstruct and translate with both ``gen`` and ``gen_test``.

        Switches to eval mode for the forward passes and then unconditionally
        back to train mode.

        Returns:
            (xa, xr_current, xt_current, xb, xr, xt): the inputs, the
            ``gen`` reconstruction/translation and the ``gen_test``
            reconstruction/translation.
        """
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        xr_current, xt_current = self._recon_and_translate(self.gen, xa, xb)
        xr, xt = self._recon_and_translate(self.gen_test, xa, xb)
        self.train()
        return xa, xr_current, xt_current, xb, xr, xt

    def translate_k_shot(self, co_data, cl_data, k):
        """Translate ``co_data[0]`` with class codes taken from ``cl_data[0]``.

        For ``k != 1`` the class codes are averaged over the batch in groups
        of ``k`` (see ``_pool_class_codes``) before decoding. Uses
        ``gen_test`` and leaves the module in eval mode.
        """
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        s_xb_current = self.gen_test.enc_class_model(xb)
        if k != 1:
            s_xb_current = self._pool_class_codes(s_xb_current, k)
        return self.gen_test.decode(c_xa_current, s_xb_current)

    def compute_k_style(self, style_batch, k):
        """Return ``gen_test`` class codes of ``style_batch`` averaged in groups of ``k``.

        Leaves the module in eval mode.
        """
        self.eval()
        style_batch = style_batch.cuda()
        return self._pool_class_codes(
            self.gen_test.enc_class_model(style_batch), k)

    def translate_simple(self, content_image, class_code):
        """Decode ``content_image``'s content code with a precomputed ``class_code``.

        Uses ``gen_test`` and leaves the module in eval mode.
        """
        self.eval()
        xa = content_image.cuda()
        s_xb_current = class_code.cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        return self.gen_test.decode(c_xa_current, s_xb_current)
Пример #2
0
class SEMIT(nn.Module):
    """FUNIT-style few-shot translation model with octave-ratio arguments.

    ``octave_alpha`` (generator) and ``constant_octave`` (discriminator) are
    forwarded as ``alpha_in``/``alpha_out`` to the encoders/discriminator and
    as the ratio argument of ``decode``. Presumably they set the
    low-frequency channel ratio of octave convolutions -- TODO confirm in
    FewShotGen / GPPatchMcResDis. ``gen_test`` is a deep copy of ``gen`` made
    at construction and used by the inference helpers.
    """

    def __init__(self, hp):
        super(SEMIT, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        self.gen_test = copy.deepcopy(self.gen)
        # Pooling and (log-)softmax pairs handed to entropy_loss in the
        # generator update; the _hf/_lf suffixes presumably mean high/low
        # frequency.
        self.pooling_hf = Avgpool()
        self.logsoftmax_hf = nn.LogSoftmax(dim=1).cuda()
        self.softmax_hf = nn.Softmax(dim=1).cuda()
        self.pooling_lf = Avgpool(kernel_size=2, stride=2)
        self.logsoftmax_lf = nn.LogSoftmax(dim=1).cuda()
        self.softmax_lf = nn.Softmax(dim=1).cuda()

    @staticmethod
    def _pool_class_codes(class_codes, k):
        """Average class codes over the batch in consecutive groups of ``k``.

        The batch dimension is moved last so ``avg_pool1d`` (whose stride
        defaults to the kernel size) averages each run of ``k`` codes; a
        trailing ``batch % k`` remainder is dropped. A trailing singleton
        dimension is expected (``squeeze(-1)`` / ``unsqueeze(-1)``) --
        presumably codes of shape (B, C, 1, 1), TODO confirm.
        """
        pooled = torch.nn.functional.avg_pool1d(
            class_codes.squeeze(-1).permute(1, 2, 0), k)
        return pooled.permute(2, 0, 1).unsqueeze(-1)

    def forward(self, co_data, cl_data, octave_alpha, hp, mode,
                constant_octave=0.25):
        """Run one generator or discriminator update step, including backward.

        Args:
            co_data: content batch; ``co_data[0]`` images, ``co_data[1]`` labels.
            cl_data: class batch; ``cl_data[0]`` images, ``cl_data[1]`` labels.
            octave_alpha: alpha used for every generator call.
            hp: loss weights; this method reads ``'gan_w'``, ``'r_w'`` and
                ``'fm_w'``.
            mode: ``'gen_update'`` or ``'dis_update'``.
            constant_octave: alpha used for every discriminator call.

        Returns:
            ``'gen_update'``: (l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc),
            where ``l_x_rec`` already includes the cycle reconstruction term.
            ``'dis_update'``: (l_total, l_fake_p, l_real_pre, l_reg_pre, acc).

        Raises:
            ValueError: if ``mode`` is not one of the supported values.
        """
        # Validate before touching the GPU. An ``assert`` would be stripped
        # under ``python -O`` and the call would silently return None.
        if mode not in ('gen_update', 'dis_update'):
            raise ValueError('Not support operation')
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()

        if mode == 'gen_update':
            c_xa = self.gen.enc_content(xa, alpha_in=octave_alpha, alpha_out=octave_alpha)
            s_xa = self.gen.enc_class_model(xa, alpha_in=octave_alpha, alpha_out=octave_alpha)
            s_xb = self.gen.enc_class_model(xb, alpha_in=octave_alpha, alpha_out=octave_alpha)
            xt = self.gen.decode(c_xa, s_xb, octave_alpha)  # translation
            xr = self.gen.decode(c_xa, s_xa, octave_alpha)  # reconstruction
            l_adv_t, gacc_t, xt_gan_feat = self.dis.calc_gen_loss(xt, lb, constant_octave)
            l_adv_r, gacc_r, xr_gan_feat = self.dis.calc_gen_loss(xr, la, constant_octave)
            _, xb_gan_feat = self.dis(xb, lb, alpha_in=constant_octave,
                                      alpha_out=constant_octave)
            _, xa_gan_feat = self.dis(xa, la, alpha_in=constant_octave,
                                      alpha_out=constant_octave)
            # Entropy loss on the content code.
            l_e = entropy_loss(c_xa,
                               (self.pooling_hf, self.pooling_lf),
                               (self.softmax_hf, self.softmax_lf),
                               (self.logsoftmax_hf, self.logsoftmax_lf))
            # Cycle: re-encode the translation, decode with xa's class code.
            c_xt = self.gen.enc_content(xt, alpha_in=octave_alpha, alpha_out=octave_alpha)
            xr_cyc = self.gen.decode(c_xt, s_xa, octave_alpha)
            l_x_rec_cyc = recon_criterion(xr_cyc, xa)
            # Feature matching on discriminator features averaged over dims
            # 3 and 2 (presumably spatial).
            l_c_rec = recon_criterion(xr_gan_feat.mean(3).mean(2),
                                      xa_gan_feat.mean(3).mean(2))
            l_m_rec = recon_criterion(xt_gan_feat.mean(3).mean(2),
                                      xb_gan_feat.mean(3).mean(2))
            l_x_rec = recon_criterion(xr, xa)
            # Reconstruction loss plus cycle loss (cycle weight 1).
            l_x_rec = l_x_rec + 1. * l_x_rec_cyc
            l_adv = 0.5 * (l_adv_t + l_adv_r)
            acc = 0.5 * (gacc_t + gacc_r)
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_x_rec + hp[
                'fm_w'] * (l_c_rec + l_m_rec)) + 0.01 * l_e
            l_total.backward()
            return l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc

        # mode == 'dis_update'
        # calc_grad2(resp_r, xb) needs xb to be part of the autograd graph.
        xb.requires_grad_()
        # The discriminator always runs with the constant octave ratio.
        l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb, constant_octave)
        l_real = hp['gan_w'] * l_real_pre
        # Keep the graph alive: resp_r is reused by calc_grad2 below.
        l_real.backward(retain_graph=True)
        l_reg_pre = self.dis.calc_grad2(resp_r, xb)
        l_reg = 10 * l_reg_pre
        l_reg.backward()
        with torch.no_grad():
            c_xa = self.gen.enc_content(xa, alpha_in=octave_alpha, alpha_out=octave_alpha)
            s_xb = self.gen.enc_class_model(xb, alpha_in=octave_alpha, alpha_out=octave_alpha)
            xt = self.gen.decode(c_xa, s_xb, octave_alpha)
        l_fake_p, acc_f, _ = self.dis.calc_dis_fake_loss(xt.detach(),
                                                         lb, constant_octave)
        l_fake = hp['gan_w'] * l_fake_p
        l_fake.backward()
        l_total = l_fake + l_real + l_reg
        acc = 0.5 * (acc_f + acc_r)
        return l_total, l_fake_p, l_real_pre, l_reg_pre, acc

    def test(self, co_data, cl_data):
        """Compute reconstructions and an octave-alpha sweep of translations.

        ``gen`` and ``gen_test`` both reconstruct at alpha 0.5; ``gen`` also
        translates at 0.5, and ``gen_test`` translates at alpha = 0.0, 0.1,
        ..., 1.0. Only the returned outputs are computed. Switches to eval
        mode for the passes, then unconditionally back to train mode.

        Returns:
            16-tuple (xa, xr_current, xt_current, xb, xr, xt_0.0, ..., xt_1.0).
        """
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()

        # Current generator: only the sweep's middle point (index 5) is used.
        mid_alpha = 5 / 10.
        c_xa_current = self.gen.enc_content(xa, alpha_in=mid_alpha, alpha_out=mid_alpha)
        s_xa_current = self.gen.enc_class_model(xa, alpha_in=mid_alpha, alpha_out=mid_alpha)
        s_xb_current = self.gen.enc_class_model(xb, alpha_in=mid_alpha, alpha_out=mid_alpha)
        xt_current = self.gen.decode(c_xa_current, s_xb_current, mid_alpha)
        xr_current = self.gen.decode(c_xa_current, s_xa_current, mid_alpha)

        # Test generator: translations over the whole sweep, reconstruction
        # only at the middle point.
        xt_set = []
        xr = None
        for alpha_index in range(11):
            alpha = alpha_index / 10.
            c_xa = self.gen_test.enc_content(xa, alpha_in=alpha, alpha_out=alpha)
            s_xb = self.gen_test.enc_class_model(xb, alpha_in=alpha, alpha_out=alpha)
            xt_set.append(self.gen_test.decode(c_xa, s_xb, alpha))
            if alpha_index == 5:
                s_xa = self.gen_test.enc_class_model(xa, alpha_in=alpha, alpha_out=alpha)
                xr = self.gen_test.decode(c_xa, s_xa, alpha)
        self.train()
        return (xa, xr_current, xt_current, xb, xr) + tuple(xt_set)

    def translate_k_shot(self, co_data, cl_data, k):
        """Translate ``co_data[0]`` with class codes taken from ``cl_data[0]``.

        For ``k != 1`` the class codes are averaged over the batch in groups
        of ``k`` (see ``_pool_class_codes``). Uses ``gen_test`` with its
        default octave arguments and leaves the module in eval mode.
        """
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        s_xb_current = self.gen_test.enc_class_model(xb)
        if k != 1:
            s_xb_current = self._pool_class_codes(s_xb_current, k)
        return self.gen_test.decode(c_xa_current, s_xb_current)

    def compute_k_style(self, style_batch, k):
        """Return ``gen_test`` class codes of ``style_batch`` averaged in groups of ``k``.

        Leaves the module in eval mode.
        """
        self.eval()
        style_batch = style_batch.cuda()
        return self._pool_class_codes(
            self.gen_test.enc_class_model(style_batch), k)

    def translate_simple(self, content_image, class_code):
        """Decode ``content_image``'s content code with a precomputed ``class_code``.

        Uses ``gen_test`` and leaves the module in eval mode.
        """
        self.eval()
        xa = content_image.cuda()
        s_xb_current = class_code.cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        return self.gen_test.decode(c_xa_current, s_xb_current)
Пример #3
0
class G2GModel(nn.Module):
    """FUNIT-style model with a two-stage G2G cycle.

    Stage one is translation; stage two is reassembly. Holds the training
    generator ``gen``, the multi-class patch discriminator ``dis`` and
    ``gen_test``. ``gen_test`` is a deep copy of ``gen`` made at
    construction and used for evaluation; how it is updated later is not
    visible here (TODO confirm in the trainer).

    Comments use the running example "xa = meerkat, xb = dog".
    """

    def __init__(self, hp):
        super(G2GModel, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        self.gen_test = copy.deepcopy(self.gen)

    @staticmethod
    def _pool_class_codes(class_codes, k):
        """Average class codes over the batch in consecutive groups of ``k``.

        The batch dimension is moved last so ``avg_pool1d`` averages each run
        of ``k`` codes. Its stride defaults to the kernel size, so a trailing
        ``batch % k`` remainder is dropped. A trailing singleton dimension is
        expected (``squeeze(-1)`` / ``unsqueeze(-1)``); presumably the codes
        have shape (B, C, 1, 1) -- TODO confirm.
        """
        pooled = torch.nn.functional.avg_pool1d(
            class_codes.squeeze(-1).permute(1, 2, 0), k)
        return pooled.permute(2, 0, 1).unsqueeze(-1)

    def cross_forward(self, xa, xb, mode):
        """Run both G2G stages on the image pair (xa, xb).

        Args:
            xa: meerkat images.
            xb: dog images.
            mode: ``'train'`` uses ``gen``; any other value uses ``gen_test``.

        Returns:
            dict with keys:
            - xa_cont / xa_class: content / class code of the meerkat
            - xb_cont / xb_class: content / class code of the dog
            - mb: translated dog (meerkat content, dog class)
            - ma: translated meerkat (dog content, meerkat class)
            - xa_rec / xb_rec: short reconstructions of meerkat / dog
            - mb_cont / mb_class: content / class code of mb
            - ma_cont / ma_class: content / class code of ma
            - ra: fully reconstructed meerkat (mb content, ma class)
            - rb: fully reconstructed dog (ma content, mb class)
        """
        gen = self.gen if mode == 'train' else self.gen_test

        # Content and class codes of the originals.
        xa_cont = gen.enc_content(xa)
        xa_class = gen.enc_class_model(xa)
        xb_cont = gen.enc_content(xb)
        xb_class = gen.enc_class_model(xb)

        # Stage 1: mixed images.
        mb = gen.decode(xa_cont, xb_class)  # translated dog
        ma = gen.decode(xb_cont, xa_class)  # translated meerkat

        # Short reconstructions with the image's own codes.
        xa_rec = gen.decode(xa_cont, xa_class)
        xb_rec = gen.decode(xb_cont, xb_class)

        # Codes of the mixed images.
        mb_cont = gen.enc_content(mb)
        mb_class = gen.enc_class_model(mb)
        ma_cont = gen.enc_content(ma)
        ma_class = gen.enc_class_model(ma)

        # Stage 2: reassembly.
        ra = gen.decode(mb_cont, ma_class)  # meerkat
        rb = gen.decode(ma_cont, mb_class)  # dog

        return {
            "xa_cont": xa_cont,
            "xa_class": xa_class,
            "xb_cont": xb_cont,
            "xb_class": xb_class,
            "mb": mb,
            "ma": ma,
            "xa_rec": xa_rec,
            "xb_rec": xb_rec,
            "mb_cont": mb_cont,
            "mb_class": mb_class,
            "ma_cont": ma_cont,
            "ma_class": ma_class,
            "ra": ra,
            "rb": rb,
        }

    def calc_g_loss(self, out, xa, xb, la, lb):
        """Compute the unweighted generator loss terms for one G2G pass.

        Args:
            out: dict returned by ``cross_forward``.
            xa, xb: original meerkat / dog images.
            la, lb: their class labels.

        Returns:
            (l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec, l_long_fm,
            l_fm_mix_rec, acc). Each paired meerkat/dog term is averaged.
        """
        def pooled(feat):
            # Discriminator features averaged over dims 3 and 2 (presumably
            # the spatial dims) before feature matching.
            return feat.mean(3).mean(2)

        # Adversarial loss, generator accuracy and discriminator features;
        # calc_gen_loss returns these for its first argument. Each image is
        # scored under the class of its class code.
        l_adv_mb, gacc_mb, mb_gan_feat = self.dis.calc_gen_loss(out["mb"], lb)
        l_adv_ma, gacc_ma, ma_gan_feat = self.dis.calc_gen_loss(out["ma"], la)
        l_adv_xa_rec, gacc_xa_rec, xa_rec_gan_feat = self.dis.calc_gen_loss(
            out["xa_rec"], la)
        l_adv_xb_rec, gacc_xb_rec, xb_rec_gan_feat = self.dis.calc_gen_loss(
            out["xb_rec"], lb)

        # Features of the real images, the targets for feature matching.
        _, xb_gan_feat = self.dis(xb, lb)
        _, xa_gan_feat = self.dis(xa, la)

        # Feature matching: short reconstructions vs. originals.
        l_fm_xa_rec = recon_criterion(pooled(xa_rec_gan_feat), pooled(xa_gan_feat))
        l_fm_xb_rec = recon_criterion(pooled(xb_rec_gan_feat), pooled(xb_gan_feat))
        l_fm_rec = 0.5 * (l_fm_xa_rec + l_fm_xb_rec)

        # Feature matching: mixed images vs. the original of their class.
        l_fm_mb = recon_criterion(pooled(mb_gan_feat), pooled(xb_gan_feat))
        l_fm_ma = recon_criterion(pooled(ma_gan_feat), pooled(xa_gan_feat))
        l_fm_m = 0.5 * (l_fm_ma + l_fm_mb)

        # Short pixel reconstruction loss.
        l_rec_xa = recon_criterion(out["xa_rec"], xa)
        l_rec_xb = recon_criterion(out["xb_rec"], xb)
        l_rec = 0.5 * (l_rec_xa + l_rec_xb)

        # Long (after reassembly) pixel reconstruction loss.
        l_long_rec_xa = recon_criterion(out["ra"], xa)
        l_long_rec_xb = recon_criterion(out["rb"], xb)
        l_long_rec = 0.5 * (l_long_rec_xa + l_long_rec_xb)

        # Long feature matching loss; only the features of ra/rb are used.
        _, _, ra_gan_feat = self.dis.calc_gen_loss(out["ra"], la)
        _, _, rb_gan_feat = self.dis.calc_gen_loss(out["rb"], lb)
        l_long_fm_xa = recon_criterion(pooled(ra_gan_feat), pooled(xa_gan_feat))
        l_long_fm_xb = recon_criterion(pooled(rb_gan_feat), pooled(xb_gan_feat))
        l_long_fm = 0.5 * (l_long_fm_xa + l_long_fm_xb)

        # Second-stage feature matching between each reassembled image and
        # the mixed image of the same class.
        l_fm_mix_rec_a = recon_criterion(pooled(ra_gan_feat), pooled(ma_gan_feat))
        l_fm_mix_rec_b = recon_criterion(pooled(rb_gan_feat), pooled(mb_gan_feat))
        l_fm_mix_rec = 0.5 * (l_fm_mix_rec_a + l_fm_mix_rec_b)

        l_adv = 0.25 * (l_adv_ma + l_adv_mb + l_adv_xa_rec + l_adv_xb_rec)
        acc = 0.25 * (gacc_ma + gacc_mb + gacc_xa_rec + gacc_xb_rec)

        return l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec, l_long_fm, l_fm_mix_rec, acc

    def calc_d_loss(self, xa, xb, la, lb, gan_weight, reg_weight):
        """Accumulate discriminator gradients for the real, penalty and fake terms.

        Calls ``backward()`` three times: real loss, gradient penalty, then
        fake loss. ``xa`` and ``xb`` must already require grad for the
        penalty; ``forward`` marks them with ``requires_grad_()``.

        Returns:
            (l_fake_pre, l_real_pre, l_reg_pre, acc). The three losses are
            unweighted. ``acc`` is the mean of the four real/fake accuracies.
        """
        # Real loss.
        l_real_pre_a, acc_r_a, resp_r_a = self.dis.calc_dis_real_loss(xa, la)
        l_real_pre_b, acc_r_b, resp_r_b = self.dis.calc_dis_real_loss(xb, lb)
        l_real_pre = 0.5 * (l_real_pre_a + l_real_pre_b)
        l_real = gan_weight * l_real_pre
        # Keep the graph alive: resp_r_a / resp_r_b are reused just below.
        l_real.backward(retain_graph=True)

        # Real gradient penalty regularization proposed by Mescheder et al.
        l_reg_pre_a = self.dis.calc_grad2(resp_r_a, xa)
        l_reg_pre_b = self.dis.calc_grad2(resp_r_b, xb)
        l_reg_pre = 0.5 * (l_reg_pre_a + l_reg_pre_b)
        l_reg = reg_weight * l_reg_pre
        l_reg.backward()

        # Mixed images for the discriminator, without generator gradients.
        with torch.no_grad():
            xa_cont = self.gen.enc_content(xa)  # meerkat
            xb_cont = self.gen.enc_content(xb)  # dog
            xa_class = self.gen.enc_class_model(xa)
            xb_class = self.gen.enc_class_model(xb)
            mb = self.gen.decode(xa_cont, xb_class)  # dog
            ma = self.gen.decode(xb_cont, xa_class)  # meerkat

        # Fake loss. Score each mixed image under the class whose class code
        # produced it: ma (meerkat class code) with la, mb (dog class code)
        # with lb. This matches calc_g_loss.
        l_fake_pre_a, acc_f_a, _ = self.dis.calc_dis_fake_loss(
            ma.detach(), la)  # meerkat
        l_fake_pre_b, acc_f_b, _ = self.dis.calc_dis_fake_loss(
            mb.detach(), lb)  # dog
        l_fake_pre = 0.5 * (l_fake_pre_a + l_fake_pre_b)
        l_fake = gan_weight * l_fake_pre
        l_fake.backward()

        acc = 0.25 * (acc_f_a + acc_f_b + acc_r_a + acc_r_b)
        return l_fake_pre, l_real_pre, l_reg_pre, acc

    def forward(self, co_data, cl_data, hp, mode):
        """Run one generator or discriminator update step, including backward.

        Args:
            co_data: content data; ``co_data[0]`` images, ``co_data[1]`` labels.
            cl_data: class data; ``cl_data[0]`` images, ``cl_data[1]`` labels.
            hp: loss weights. The generator step reads ``'gan_w'``, ``'r_w'``,
                ``'fm_rec_w'``, ``'fm_w'``, ``'rl_w'``, ``'fml_w'`` and
                ``'fml_mix_rec_w'``. The discriminator step reads
                ``'gan_w'`` and ``'reg_w'``.
            mode: ``'gen_update'`` or ``'dis_update'``.

        Returns:
            Generator: (l_total, l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec,
            l_long_fm, l_fm_mix_rec, acc).
            Discriminator: (l_total, l_fake, l_real, l_reg, acc), where
            l_fake, l_real and l_reg are unweighted and l_total is their
            weighted sum.

        Raises:
            ValueError: if ``mode`` is not one of the supported values.
        """
        # Validate before touching the GPU. An ``assert`` would be stripped
        # under ``python -O`` and the call would silently return None.
        if mode not in ('gen_update', 'dis_update'):
            raise ValueError('Not support operation')
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()

        if mode == 'gen_update':
            out = self.cross_forward(xa, xb, 'train')
            (l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec, l_long_fm,
             l_fm_mix_rec, acc) = self.calc_g_loss(out, xa, xb, la, lb)
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_rec +
                       hp['fm_rec_w'] * l_fm_rec + hp['fm_w'] * l_fm_m +
                       hp['rl_w'] * l_long_rec + hp['fml_w'] * l_long_fm +
                       hp['fml_mix_rec_w'] * l_fm_mix_rec)
            l_total.backward()
            return l_total, l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec, l_long_fm, l_fm_mix_rec, acc

        # mode == 'dis_update': the gradient penalty differentiates w.r.t.
        # the real inputs.
        xa.requires_grad_()
        xb.requires_grad_()
        l_fake, l_real, l_reg, acc = self.calc_d_loss(
            xa, xb, la, lb, hp['gan_w'], hp['reg_w'])
        # Same weighting that calc_d_loss used for back-propagation.
        l_total = hp['gan_w'] * (l_fake + l_real) + hp['reg_w'] * l_reg
        return l_total, l_fake, l_real, l_reg, acc

    def test(self, co_data, cl_data):
        """Run ``cross_forward`` with ``gen_test`` for visualisation.

        Args:
            co_data: content data; ``co_data[0]`` are meerkat images.
            cl_data: class data; ``cl_data[0]`` are dog images.

        Returns:
            (xa, xb, mb, ma, ra, rb): originals, mixed images and reassembled
            images. Switches to eval mode for the pass, then unconditionally
            back to train mode.
        """
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()  # meerkat
        xb = cl_data[0].cuda()  # dog
        out = self.cross_forward(xa, xb, 'test')
        self.train()
        return xa, xb, out['mb'], out['ma'], out['ra'], out['rb']

    def translate_k_shot(self, co_data, cl_data, k):
        """Translate ``co_data[0]`` with class codes taken from ``cl_data[0]``.

        For ``k != 1`` the class codes are averaged over the batch in groups
        of ``k`` (see ``_pool_class_codes``). Uses ``gen_test`` and leaves
        the module in eval mode.

        Returns:
            The translated images.
        """
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        xa_cont_current = self.gen_test.enc_content(xa)
        xb_class_current = self.gen_test.enc_class_model(xb)
        if k != 1:
            xb_class_current = self._pool_class_codes(xb_class_current, k)
        return self.gen_test.decode(xa_cont_current, xb_class_current)

    def compute_k_style(self, style_batch, k):
        """Return ``gen_test`` class codes of ``style_batch`` averaged in groups of ``k``.

        Leaves the module in eval mode.
        """
        self.eval()
        style_batch = style_batch.cuda()
        return self._pool_class_codes(
            self.gen_test.enc_class_model(style_batch), k)

    def translate_simple(self, content_image, class_code):
        """Decode ``content_image``'s content code with a precomputed ``class_code``.

        Uses ``gen_test`` and leaves the module in eval mode.
        """
        self.eval()
        xa = content_image.cuda()
        xb_class_current = class_code.cuda()
        xa_cont_current = self.gen_test.enc_content(xa)
        return self.gen_test.decode(xa_cont_current, xb_class_current)

    def translate_cross(self, content_image_a, content_image_b):
        """Run the full G2G cycle with ``gen_test`` on two images.

        Args:
            content_image_a: first image batch (plays ``xa``).
            content_image_b: second image batch (plays ``xb``).

        Returns:
            The ``cross_forward`` dict with all intermediate images and
            codes. Leaves the module in eval mode.
        """
        self.eval()
        xa = content_image_a.cuda()
        xb = content_image_b.cuda()
        return self.cross_forward(xa, xb, 'test')
Пример #4
0
class FUNITModel(nn.Module):
    """FUNIT few-shot image-to-image translation model.

    Bundles the training generator ``gen``, the multi-class patch
    discriminator ``dis`` and ``gen_test``, a deep copy of ``gen`` taken at
    construction that the inference helpers use. How ``gen_test`` is kept in
    sync with ``gen`` afterwards is not visible here -- TODO confirm in the
    trainer.

    Comments use the running example "xa = meerkat (content image),
    xb = dog (class image)".
    """

    def __init__(self, hp):
        super(FUNITModel, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        self.gen_test = copy.deepcopy(self.gen)

    @staticmethod
    def _pool_class_codes(class_codes, k):
        """Average class codes over the batch in consecutive groups of ``k``.

        The batch dimension is moved last so ``avg_pool1d`` (whose stride
        defaults to the kernel size) averages each run of ``k`` codes; a
        trailing ``batch % k`` remainder is dropped. A trailing singleton
        dimension is expected (``squeeze(-1)`` / ``unsqueeze(-1)``) --
        presumably codes of shape (B, C, 1, 1), TODO confirm.
        """
        pooled = torch.nn.functional.avg_pool1d(
            class_codes.squeeze(-1).permute(1, 2, 0), k)
        return pooled.permute(2, 0, 1).unsqueeze(-1)

    @staticmethod
    def _recon_and_translate(gen, xa, xb):
        """Return ``(xr, xt)`` from ``gen``.

        ``xr`` is ``xa`` reconstructed; ``xt`` is ``xa``'s content decoded
        with ``xb``'s class code (e.g. the translated dog).
        """
        c_xa = gen.enc_content(xa)  # content code of the meerkat
        s_xa = gen.enc_class_model(xa)  # class code of the meerkat
        s_xb = gen.enc_class_model(xb)  # class code of the dog
        xt = gen.decode(c_xa, s_xb)  # translated dog
        xr = gen.decode(c_xa, s_xa)  # reconstructed meerkat
        return xr, xt

    def forward(self, co_data, cl_data, hp, mode):
        """Run one generator or discriminator update step, including backward.

        Args:
            co_data: content data; ``co_data[0]`` images, ``co_data[1]`` labels.
            cl_data: class data; ``cl_data[0]`` images, ``cl_data[1]`` labels.
            hp: loss weights; this method reads ``'gan_w'``, ``'r_w'`` and
                ``'fm_w'``.
            mode: ``'gen_update'`` or ``'dis_update'``.

        Returns:
            Generator:
                - l_total: overall generator loss
                - l_adv: adversarial loss of generator
                - l_x_rec: reconstruction loss
                - l_c_rec: feature matching loss, reconstructed image
                - l_m_rec: feature matching loss, translated image
                - acc: generator accuracy
            Discriminator:
                - l_total: overall (weighted) discriminator loss
                - l_fake_p: discriminator fake loss (unweighted)
                - l_real_pre: discriminator real loss (unweighted)
                - l_reg_pre: real gradient penalty term (unweighted)
                - acc: discriminator accuracy

        Raises:
            ValueError: if ``mode`` is not one of the supported values.
        """
        # Validate before touching the GPU. An ``assert`` would be stripped
        # under ``python -O`` and the call would silently return None.
        if mode not in ('gen_update', 'dis_update'):
            raise ValueError('Not support operation')
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()

        if mode == 'gen_update':
            xr, xt = self._recon_and_translate(self.gen, xa, xb)

            # Adversarial loss, generator accuracy and features; calc_gen_loss
            # returns these for its first argument.
            l_adv_t, gacc_t, xt_gan_feat = self.dis.calc_gen_loss(xt, lb)
            l_adv_r, gacc_r, xr_gan_feat = self.dis.calc_gen_loss(xr, la)

            # Features of the original dog (xb) and meerkat (xa) for
            # feature matching.
            _, xb_gan_feat = self.dis(xb, lb)
            _, xa_gan_feat = self.dis(xa, la)

            # Feature matching on features averaged over dims 3 and 2
            # (presumably spatial).
            l_c_rec = recon_criterion(
                xr_gan_feat.mean(3).mean(2),
                xa_gan_feat.mean(3).mean(2))
            l_m_rec = recon_criterion(
                xt_gan_feat.mean(3).mean(2),
                xb_gan_feat.mean(3).mean(2))

            # Short reconstruction loss.
            l_x_rec = recon_criterion(xr, xa)

            l_adv = 0.5 * (l_adv_t + l_adv_r)
            acc = 0.5 * (gacc_t + gacc_r)

            # Overall loss: adversarial, reconstruction and feature matching.
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_x_rec + hp['fm_w'] *
                       (l_c_rec + l_m_rec))
            l_total.backward()
            return l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc

        # mode == 'dis_update'
        # calc_grad2(resp_r, xb) needs xb to be part of the autograd graph.
        xb.requires_grad_()

        # Discriminator's real loss; keep the graph for the penalty below.
        l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb)
        l_real = hp['gan_w'] * l_real_pre
        l_real.backward(retain_graph=True)

        # Real gradient penalty regularization proposed by Mescheder et al.
        l_reg_pre = self.dis.calc_grad2(resp_r, xb)
        l_reg = 10 * l_reg_pre
        l_reg.backward()

        # Generate images for the discriminator without generator gradients.
        with torch.no_grad():
            c_xa = self.gen.enc_content(xa)
            s_xb = self.gen.enc_class_model(xb)
            xt = self.gen.decode(c_xa, s_xb)

        # Discriminator's fake loss.
        l_fake_p, acc_f, _ = self.dis.calc_dis_fake_loss(xt.detach(), lb)
        l_fake = hp['gan_w'] * l_fake_p
        l_fake.backward()
        l_total = l_fake + l_real + l_reg

        acc = 0.5 * (acc_f + acc_r)
        return l_total, l_fake_p, l_real_pre, l_reg_pre, acc

    def test(self, co_data, cl_data):
        """Reconstruct and translate with both ``gen`` and ``gen_test``.

        Args:
            co_data: content data; ``co_data[0]`` images (domain A).
            cl_data: class data; ``cl_data[0]`` images (domain B).

        Returns:
            - xa: original image of domain A
            - xr_current: reconstruction by ``gen``
            - xt_current: translation by ``gen``
            - xb: original image of domain B
            - xr: reconstruction by ``gen_test``
            - xt: translation by ``gen_test``

        Switches to eval mode for the passes, then unconditionally back to
        train mode.
        """
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        xr_current, xt_current = self._recon_and_translate(self.gen, xa, xb)
        xr, xt = self._recon_and_translate(self.gen_test, xa, xb)
        self.train()
        return xa, xr_current, xt_current, xb, xr, xt

    def translate_k_shot(self, co_data, cl_data, k):
        """Translate ``co_data[0]`` with class codes taken from ``cl_data[0]``.

        Args:
            co_data: content data; ``co_data[0]`` images.
            cl_data: class data; ``cl_data[0]`` images.
            k: group size; for ``k != 1`` class codes are averaged over the
                batch in groups of ``k`` (see ``_pool_class_codes``).

        Returns:
            The translated images. Uses ``gen_test`` and leaves the module
            in eval mode.
        """
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        s_xb_current = self.gen_test.enc_class_model(xb)
        if k != 1:
            s_xb_current = self._pool_class_codes(s_xb_current, k)
        return self.gen_test.decode(c_xa_current, s_xb_current)

    def compute_k_style(self, style_batch, k):
        """Return ``gen_test`` class codes of ``style_batch`` averaged in groups of ``k``.

        Leaves the module in eval mode.
        """
        self.eval()
        style_batch = style_batch.cuda()
        return self._pool_class_codes(
            self.gen_test.enc_class_model(style_batch), k)

    def translate_simple(self, content_image, class_code):
        """Decode ``content_image``'s content code with a precomputed ``class_code``.

        Uses ``gen_test`` and leaves the module in eval mode.
        """
        self.eval()
        xa = content_image.cuda()
        s_xb_current = class_code.cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        return self.gen_test.decode(c_xa_current, s_xb_current)
Пример #5
0
class FUNITModel(nn.Module):
    """Two-domain FUNIT model: one generator and one discriminator per domain.

    Domain ``a`` is the "human" domain and domain ``b`` the "anime" domain
    (per the attribute comments). ``gen_test_a``/``gen_test_b`` are deep
    copies of the generators used by the inference helpers (``test``,
    ``translate_k_shot``, ``compute_k_style``, ``translate_simple``).
    """

    def __init__(self, hp):
        """Build both generator/discriminator pairs and their test copies.

        Args:
            hp: hyper-parameter dict; reads the sub-configs ``'gen_a'``,
                ``'gen_b'``, ``'dis_a'`` and ``'dis_b'``.
        """
        super(FUNITModel, self).__init__()
        self.gen_a = FewShotGen(hp['gen_a'])  # human domain Generator
        self.gen_b = FewShotGen(hp['gen_b'])  # anime domain Generator
        self.dis_a = GPPatchMcResDis(hp['dis_a'])  # human domain Discriminator
        self.dis_b = GPPatchMcResDis(hp['dis_b'])  # anime domain Discriminator

        self.gen_test_a = copy.deepcopy(self.gen_a)
        self.gen_test_b = copy.deepcopy(self.gen_b)

    def forward(self, co_data, cl_data, hp, mode):
        """Compute the losses for one update step and back-propagate them.

        Args:
            co_data: ``(images, labels)`` for domain a; both moved to GPU.
            cl_data: ``(images, labels)`` for domain b; both moved to GPU.
            hp: hyper-parameters. Reads ``'mode'`` (``'B'`` additionally
                includes the reconstruction terms), ``'gan_w'`` and, for
                ``'gen_update'``, ``'r_w'``, ``'fm_w'``, ``'c_w'``, ``'s_w'``.
            mode: ``'gen_update'`` or ``'dis_update'``.

        Returns:
            ``'gen_update'``: ``(total_loss, gan_loss, feat_loss, rec_loss,
            content_loss, style_loss, acc)``.
            ``'dis_update'``: ``(total_loss, fake_loss, real_loss, reg_loss,
            acc)``.

        Gradients are only accumulated via ``backward()``; no optimizer step
        happens here.
        """
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()
        if mode == 'gen_update':
            # content and style codes for the human (a) and anime (b) domains
            c_xa = self.gen_a.enc_content(xa)
            c_xb = self.gen_b.enc_content(xb)
            s_xa = self.gen_a.enc_class_model(xa)
            s_xb = self.gen_b.enc_class_model(xb)
            # reconstruction: each image from its own content and style code
            xr_a = self.gen_a.decode(c_xa, s_xa)
            xr_b = self.gen_b.decode(c_xb, s_xb)
            # translation: content from one domain, style from the other
            xt_a2b = self.gen_b.decode(c_xa, s_xb)
            xt_b2a = self.gen_a.decode(c_xb, s_xa)
            # re-encode the translations for the content/style losses
            c_xt_a2b = self.gen_b.enc_content(xt_a2b)
            c_xt_b2a = self.gen_a.enc_content(xt_b2a)
            s_xt_a2b = self.gen_b.enc_class_model(xt_a2b)
            s_xt_b2a = self.gen_a.enc_class_model(xt_b2a)

            ############ calculate loss ############
            # gan loss
            # NOTE(review): the single-domain FUNITModel variants call
            # ``calc_gen_loss`` on GPPatchMcResDis; confirm this project's
            # discriminator really defines ``calc_gan_loss``.
            xt_a2b_gan_loss, xt_a2b_gan_acc, xt_a2b_gan_feat = self.dis_b.calc_gan_loss(
                xt_a2b, lb)
            xt_b2a_gan_loss, xt_b2a_gan_acc, xt_b2a_gan_feat = self.dis_a.calc_gan_loss(
                xt_b2a, la)
            xr_a_gan_loss, xr_a_gan_acc, xr_a_gan_feat = self.dis_a.calc_gan_loss(
                xr_a, la)
            xr_b_gan_loss, xr_b_gan_acc, xr_b_gan_feat = self.dis_b.calc_gan_loss(
                xr_b, lb)
            # mode 'B' also includes the reconstruction terms
            if hp['mode'] == 'B':
                gan_loss = (xt_a2b_gan_loss + xt_b2a_gan_loss + xr_a_gan_loss +
                            xr_b_gan_loss) * 0.5
            else:
                gan_loss = xt_a2b_gan_loss + xt_b2a_gan_loss
            # feature-matching loss on spatially averaged discriminator features
            _, xb_gan_feat = self.dis_b(xb, lb)
            _, xa_gan_feat = self.dis_a(xa, la)
            xr_feat_loss = (
                recon_criterion(xr_a_gan_feat.mean(3).mean(2),
                                xa_gan_feat.mean(3).mean(2)) +
                recon_criterion(xr_b_gan_feat.mean(3).mean(2),
                                xb_gan_feat.mean(3).mean(2)))
            xt_feat_loss = (
                recon_criterion(xt_b2a_gan_feat.mean(3).mean(2),
                                xa_gan_feat.mean(3).mean(2)) +
                recon_criterion(xt_a2b_gan_feat.mean(3).mean(2),
                                xb_gan_feat.mean(3).mean(2)))
            if hp['mode'] == 'B':
                feat_loss = xt_feat_loss + xr_feat_loss
            else:
                feat_loss = xt_feat_loss
            # reconstruction loss
            xa_rec_loss = recon_criterion(xr_a, xa)
            xb_rec_loss = recon_criterion(xr_b, xb)
            rec_loss = (xa_rec_loss + xb_rec_loss)
            # content loss: translation keeps the source content code
            content_a2b_loss = recon_criterion(c_xa, c_xt_a2b)
            content_b2a_loss = recon_criterion(c_xb, c_xt_b2a)
            content_loss = (content_a2b_loss + content_b2a_loss)
            # style loss: translation carries the target style code
            style_a2b_loss = recon_criterion(s_xb, s_xt_a2b)
            style_b2a_loss = recon_criterion(s_xa, s_xt_b2a)
            style_loss = (style_a2b_loss + style_b2a_loss)
            # total loss
            total_loss = (hp['gan_w'] * gan_loss + hp['r_w'] * rec_loss +
                          hp['fm_w'] * feat_loss + hp['c_w'] * content_loss +
                          hp['s_w'] * style_loss)
            total_loss.backward()
            # accuracy averaged over the two translations
            acc = 0.5 * (xt_a2b_gan_acc + xt_b2a_gan_acc)
            return total_loss, gan_loss, feat_loss, rec_loss, content_loss, style_loss, acc
        elif mode == 'dis_update':
            xb.requires_grad_()
            xa.requires_grad_()
            ################# dis_a #################
            dis_a_real_loss, dis_a_real_acc, dis_a_real_resp = self.dis_a.calc_dis_real_loss(
                xa, la)  # real loss
            dis_a_real_loss = hp['gan_w'] * dis_a_real_loss
            # retain the graph: calc_grad2 below reuses dis_a_real_resp
            dis_a_real_loss.backward(retain_graph=True)
            dis_a_reg_loss = 10 * self.dis_a.calc_grad2(dis_a_real_resp,
                                                        xa)  # reg loss
            dis_a_reg_loss.backward()
            # fake loss
            with torch.no_grad():
                c_xb = self.gen_b.enc_content(xb)
                c_xa = self.gen_a.enc_content(xa)
                s_xa = self.gen_a.enc_class_model(xa)
                xr_a = self.gen_a.decode(c_xa, s_xa)
                xt_b2a = self.gen_a.decode(c_xb, s_xa)
            dis_at_fake_loss, dis_at_fake_acc, dis_at_fake_resp = self.dis_a.calc_dis_fake_loss(
                xt_b2a.detach(), la)
            dis_ar_fake_loss, dis_ar_fake_acc, dis_ar_fake_resp = self.dis_a.calc_dis_fake_loss(
                xr_a.detach(), la)
            if hp['mode'] == 'B':
                dis_a_fake_loss = hp['gan_w'] * (dis_at_fake_loss +
                                                 dis_ar_fake_loss) * 0.5
            else:
                dis_a_fake_loss = hp['gan_w'] * dis_at_fake_loss
            dis_a_fake_loss.backward()
            ################# dis_b #################
            dis_b_real_loss, dis_b_real_acc, dis_b_real_resp = self.dis_b.calc_dis_real_loss(
                xb, lb)  # real loss
            # Weight by gan_w exactly like dis_a's real loss and the fake losses.
            dis_b_real_loss = hp['gan_w'] * dis_b_real_loss
            # retain the graph: calc_grad2 below reuses dis_b_real_resp
            dis_b_real_loss.backward(retain_graph=True)
            dis_b_reg_loss = 10 * self.dis_b.calc_grad2(dis_b_real_resp,
                                                        xb)  # reg loss
            dis_b_reg_loss.backward()
            # fake loss
            with torch.no_grad():
                c_xa = self.gen_a.enc_content(xa)
                c_xb = self.gen_b.enc_content(xb)
                s_xb = self.gen_b.enc_class_model(xb)
                xr_b = self.gen_b.decode(c_xb, s_xb)
                xt_a2b = self.gen_b.decode(c_xa, s_xb)
            dis_bt_fake_loss, dis_bt_fake_acc, dis_bt_fake_resp = self.dis_b.calc_dis_fake_loss(
                xt_a2b.detach(), lb)
            dis_br_fake_loss, dis_br_fake_acc, dis_br_fake_resp = self.dis_b.calc_dis_fake_loss(
                xr_b.detach(), lb)
            if hp['mode'] == 'B':
                dis_b_fake_loss = hp['gan_w'] * (dis_bt_fake_loss +
                                                 dis_br_fake_loss) * 0.5
            else:
                dis_b_fake_loss = hp['gan_w'] * dis_bt_fake_loss
            dis_b_fake_loss.backward()

            real_loss = (dis_a_real_loss + dis_b_real_loss)
            fake_loss = (dis_a_fake_loss + dis_b_fake_loss)
            reg_loss = (dis_a_reg_loss + dis_b_reg_loss)
            total_loss = (dis_a_fake_loss + dis_b_fake_loss + dis_a_real_loss +
                          dis_b_real_loss + dis_a_reg_loss + dis_b_reg_loss)
            acc = 0.25 * (dis_at_fake_acc + dis_bt_fake_acc + dis_a_real_acc +
                          dis_b_real_acc)
            return total_loss, fake_loss, real_loss, reg_loss, acc
        else:
            assert 0, 'Not support operation'

    def test(self, co_data, cl_data):
        """Reconstruct and cross-translate a content/class pair with the test generators.

        Runs in eval mode and switches back to train mode before returning.

        Returns:
            ``(xa, xr_a, xt_b2a, xb, xr_b, xt_a2b)``.
        """
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()

        c_xa = self.gen_test_a.enc_content(xa)
        c_xb = self.gen_test_b.enc_content(xb)
        s_xa = self.gen_test_a.enc_class_model(xa)
        s_xb = self.gen_test_b.enc_class_model(xb)
        xr_a = self.gen_test_a.decode(c_xa, s_xa)
        xr_b = self.gen_test_b.decode(c_xb, s_xb)
        xt_a2b = self.gen_test_b.decode(c_xa, s_xb)
        xt_b2a = self.gen_test_a.decode(c_xb, s_xa)

        self.train()
        return xa, xr_a, xt_b2a, xb, xr_b, xt_a2b

    @staticmethod
    def _pool_k_shot(s_xb, k):
        """Average consecutive groups of ``k`` style codes along the batch axis.

        For an input of shape ``(N, C, H, 1)`` (presumably ``H == 1`` for
        ``enc_class_model`` output -- confirm) this returns
        ``(N // k, C, H, 1)``.
        """
        # Batch axis goes last so avg_pool1d pools across images.
        pooled = torch.nn.functional.avg_pool1d(
            s_xb.squeeze(-1).permute(1, 2, 0), k)
        return pooled.permute(2, 0, 1).unsqueeze(-1)

    def translate_k_shot(self, co_data, cl_data, k):
        """Translate domain-a content images into the style of k domain-b images.

        Args:
            co_data: sequence whose element 0 is the content image batch.
            cl_data: sequence whose element 0 is the class image batch.
            k: number of shots. With ``k == 1`` each class image's style code
                is used directly; otherwise consecutive groups of ``k`` style
                codes are averaged.

        Returns:
            The output of ``gen_test_b.decode`` for the content codes and
            (pooled) style codes.
        """
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa = self.gen_test_a.enc_content(xa)
        s_xb = self.gen_test_b.enc_class_model(xb)
        if k != 1:
            s_xb = self._pool_k_shot(s_xb, k)
        # Decode into domain b, consistent with translate_simple.
        return self.gen_test_b.decode(c_xa, s_xb)

    def compute_k_style(self, style_batch, k):
        """Return domain-b style codes of ``style_batch`` averaged in groups of ``k``."""
        self.eval()
        style_batch = style_batch.cuda()
        s_xb = self.gen_test_b.enc_class_model(style_batch)
        return self._pool_k_shot(s_xb, k)

    def translate_simple(self, content_image, class_code):
        """Decode a domain-a content image with a precomputed domain-b class code."""
        self.eval()
        xa = content_image.cuda()
        s_xb = class_code.cuda()
        c_xa = self.gen_test_a.enc_content(xa)
        xt = self.gen_test_b.decode(c_xa, s_xb)
        return xt
Пример #6
0
class FUNITModel(nn.Module):
    def __init__(self, hp):
        """Build the generator, the discriminator and an inference copy of the generator.

        Args:
            hp: hyper-parameter dict; ``hp['gen']`` configures ``FewShotGen``
                and ``hp['dis']`` configures ``GPPatchMcResDis``.
        """
        super().__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        # Deep copy read by the inference helpers (test / translate_*).
        self.gen_test = copy.deepcopy(self.gen)

    def forward(self, co_data, cl_data, hp, mode):
        """Compute the FUNIT losses for one update step and back-propagate them.

        Args:
            co_data: ``(images, labels)`` for the content batch; moved to GPU.
            cl_data: ``(images, labels)`` for the class batch; moved to GPU.
            hp: hyper-parameters; reads ``'gan_w'`` and, for
                ``'gen_update'``, ``'r_w'`` and ``'fm_w'``.
            mode: ``'gen_update'`` or ``'dis_update'``.

        Returns:
            ``'gen_update'``: ``(l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc)``.
            ``'dis_update'``: ``(l_total, l_fake_p, l_real_pre, l_reg_pre, acc)``.

        Gradients are only accumulated (through ``_backward``); no optimizer
        step happens here.
        """
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()
        if mode == 'gen_update':
            c_xa = self.gen.enc_content(xa)
            s_xa = self.gen.enc_class_model(xa)
            s_xb = self.gen.enc_class_model(xb)
            xt = self.gen.decode(c_xa, s_xb)  # translation
            xr = self.gen.decode(c_xa, s_xa)  # reconstruction
            l_adv_t, gacc_t, xt_gan_feat = self.dis.calc_gen_loss(xt, lb)
            l_adv_r, gacc_r, xr_gan_feat = self.dis.calc_gen_loss(xr, la)
            _, xb_gan_feat = self.dis(xb, lb)
            _, xa_gan_feat = self.dis(xa, la)
            # feature matching on spatially averaged discriminator features
            l_c_rec = recon_criterion(
                xr_gan_feat.mean(3).mean(2),
                xa_gan_feat.mean(3).mean(2))
            l_m_rec = recon_criterion(
                xt_gan_feat.mean(3).mean(2),
                xb_gan_feat.mean(3).mean(2))
            l_x_rec = recon_criterion(xr, xa.float())
            l_adv = 0.5 * (l_adv_t + l_adv_r)
            acc = 0.5 * (gacc_t + gacc_r)
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_x_rec + hp['fm_w'] *
                       (l_c_rec + l_m_rec))
            self._backward(l_total)
            return l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc
        elif mode == 'dis_update':
            # xb needs grad because calc_grad2(resp_r, xb) below uses it
            # (presumably a gradient penalty w.r.t. the real images).
            xb.requires_grad_()
            l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb)
            l_real = hp['gan_w'] * l_real_pre
            # retain the graph: calc_grad2 reuses resp_r afterwards
            self._backward(l_real, retain_graph=True)
            l_reg_pre = self.dis.calc_grad2(resp_r, xb)
            l_reg = 10 * l_reg_pre
            self._backward(l_reg)
            with torch.no_grad():
                c_xa = self.gen.enc_content(xa)
                s_xb = self.gen.enc_class_model(xb)
                xt = self.gen.decode(c_xa, s_xb)
            l_fake_p, acc_f, resp_f = self.dis.calc_dis_fake_loss(
                xt.detach(), lb)
            l_fake = hp['gan_w'] * l_fake_p
            self._backward(l_fake)
            l_total = l_fake + l_real + l_reg
            acc = 0.5 * (acc_f + acc_r)
            return l_total, l_fake_p, l_real_pre, l_reg_pre, acc
        else:
            assert 0, 'Not support operation'

    def _backward(self, loss, retain_graph=None):
        """Back-propagate ``loss``, through Apex AMP loss scaling when enabled.

        When ``GlobalConstants.usingApex`` is set, the optimizers stored by
        ``setOptimizersForApex`` are handed to ``amp.scale_loss``.
        """
        if GlobalConstants.usingApex:
            with amp.scale_loss(
                    loss, [self.gen_opt, self.dis_opt]) as scaled_loss:
                scaled_loss.backward(retain_graph=retain_graph)
        else:
            loss.backward(retain_graph=retain_graph)

    def test(self, co_data, cl_data):
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen.enc_content(xa)
        s_xa_current = self.gen.enc_class_model(xa)
        s_xb_current = self.gen.enc_class_model(xb)
        xt_current = self.gen.decode(c_xa_current, s_xb_current)
        xr_current = self.gen.decode(c_xa_current, s_xa_current)
        c_xa = self.gen_test.enc_content(xa)
        s_xa = self.gen_test.enc_class_model(xa)
        s_xb = self.gen_test.enc_class_model(xb)
        xt = self.gen_test.decode(c_xa, s_xb)
        xr = self.gen_test.decode(c_xa, s_xa)
        self.train()
        return xa, xr_current, xt_current, xb, xr, xt

    def translate_k_shot(self, co_data, cl_data, k):
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        if k == 1:
            c_xa_current = self.gen_test.enc_content(xa)
            s_xb_current = self.gen_test.enc_class_model(xb)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        else:
            s_xb_current_before = self.gen_test.enc_class_model(xb)
            s_xb_current_after = s_xb_current_before.squeeze(-1).permute(
                1, 2, 0)
            s_xb_current_pool = torch.nn.functional.avg_pool1d(
                s_xb_current_after, k)
            s_xb_current = s_xb_current_pool.permute(2, 0, 1).unsqueeze(-1)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current

    def compute_k_style(self, style_batch, k):
        self.eval()
        style_batch = style_batch.cuda()
        s_xb_before = self.gen_test.enc_class_model(style_batch)
        s_xb_after = s_xb_before.squeeze(-1).permute(1, 2, 0)
        s_xb_pool = torch.nn.functional.avg_pool1d(s_xb_after, k)
        s_xb = s_xb_pool.permute(2, 0, 1).unsqueeze(-1)
        return s_xb

    def translate_simple(self, content_image, class_code):
        self.eval()
        xa = content_image.cuda()
        s_xb_current = class_code.cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current

    def setOptimizersForApex(self, gen_opt, dis_opt):
        """Store the optimizers that ``forward`` passes to ``amp.scale_loss``.

        They are only used when ``GlobalConstants.usingApex`` is set; this
        must be called before ``forward`` in that case, otherwise
        ``self.gen_opt``/``self.dis_opt`` do not exist.

        Args:
            gen_opt: generator optimizer.
            dis_opt: discriminator optimizer.
        """
        self.gen_opt = gen_opt
        self.dis_opt = dis_opt