class FUNITModel(nn.Module):
    # FUNIT few-shot image-to-image translation model: one few-shot
    # generator, one multi-class patch discriminator, and a deep copy of the
    # generator (`gen_test`) that is used for evaluation/inference.
    def __init__(self, hp):
        super(FUNITModel, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        # Snapshot of the generator used by test/translate helpers.
        self.gen_test = copy.deepcopy(self.gen)

    def forward(self, co_data, cl_data, hp, mode):
        """Run one generator or discriminator update step.

        Args:
            co_data: (content image, content label) batch.
            cl_data: (class image, class label) batch.
            hp: hyperparameter dict with loss weights ('gan_w', 'r_w', 'fm_w').
            mode: 'gen_update' or 'dis_update'.

        Returns:
            A tuple of loss terms and an accuracy; gradients are already
            accumulated via backward() inside this call.
        """
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()
        if mode == 'gen_update':
            c_xa = self.gen.enc_content(xa)
            s_xa = self.gen.enc_class_model(xa)
            s_xb = self.gen.enc_class_model(xb)
            xt = self.gen.decode(c_xa, s_xb)  # translation
            xr = self.gen.decode(c_xa, s_xa)  # reconstruction
            # Adversarial losses, generator accuracies, and discriminator
            # feature maps for the generated images.
            l_adv_t, gacc_t, xt_gan_feat = self.dis.calc_gen_loss(xt, lb)
            l_adv_r, gacc_r, xr_gan_feat = self.dis.calc_gen_loss(xr, la)
            # Discriminator features of the real images (feature matching).
            _, xb_gan_feat = self.dis(xb, lb)
            _, xa_gan_feat = self.dis(xa, la)
            # Feature-matching losses on globally average-pooled features.
            l_c_rec = recon_criterion(
                xr_gan_feat.mean(3).mean(2),
                xa_gan_feat.mean(3).mean(2))
            l_m_rec = recon_criterion(
                xt_gan_feat.mean(3).mean(2),
                xb_gan_feat.mean(3).mean(2))
            # Pixel-level reconstruction loss.
            l_x_rec = recon_criterion(xr, xa)
            l_adv = 0.5 * (l_adv_t + l_adv_r)
            acc = 0.5 * (gacc_t + gacc_r)
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_x_rec + hp['fm_w'] *
                       (l_c_rec + l_m_rec))
            l_total.backward()
            return l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc
        elif mode == 'dis_update':
            # Real images need grad for the R1 gradient penalty below.
            xb.requires_grad_()
            l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb)
            l_real = hp['gan_w'] * l_real_pre
            # Graph retained so calc_grad2 can differentiate resp_r w.r.t. xb.
            l_real.backward(retain_graph=True)
            l_reg_pre = self.dis.calc_grad2(resp_r, xb)
            l_reg = 10 * l_reg_pre  # fixed gradient-penalty weight
            l_reg.backward()
            # Generate fakes without building generator gradients.
            with torch.no_grad():
                c_xa = self.gen.enc_content(xa)
                s_xb = self.gen.enc_class_model(xb)
                xt = self.gen.decode(c_xa, s_xb)
            l_fake_p, acc_f, resp_f = self.dis.calc_dis_fake_loss(
                xt.detach(), lb)
            l_fake = hp['gan_w'] * l_fake_p
            l_fake.backward()
            l_total = l_fake + l_real + l_reg
            acc = 0.5 * (acc_f + acc_r)
            return l_total, l_fake_p, l_real_pre, l_reg_pre, acc
        else:
            assert 0, 'Not support operation'

    def test(self, co_data, cl_data):
        """Translate and reconstruct with both current and test generators."""
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen.enc_content(xa)
        s_xa_current = self.gen.enc_class_model(xa)
        s_xb_current = self.gen.enc_class_model(xb)
        xt_current = self.gen.decode(c_xa_current, s_xb_current)
        xr_current = self.gen.decode(c_xa_current, s_xa_current)
        c_xa = self.gen_test.enc_content(xa)
        s_xa = self.gen_test.enc_class_model(xa)
        s_xb = self.gen_test.enc_class_model(xb)
        xt = self.gen_test.decode(c_xa, s_xb)
        xr = self.gen_test.decode(c_xa, s_xa)
        self.train()
        return xa, xr_current, xt_current, xb, xr, xt

    def translate_k_shot(self, co_data, cl_data, k):
        """Translate xa into the class of xb using k class (style) images."""
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        if k == 1:
            # NOTE(review): content code is recomputed here; redundant with
            # the assignment just above, but harmless.
            c_xa_current = self.gen_test.enc_content(xa)
            s_xb_current = self.gen_test.enc_class_model(xb)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        else:
            # Average the k class codes via 1-D average pooling over the
            # batch dimension.
            s_xb_current_before = self.gen_test.enc_class_model(xb)
            s_xb_current_after = s_xb_current_before.squeeze(-1).permute(
                1, 2, 0)
            s_xb_current_pool = torch.nn.functional.avg_pool1d(
                s_xb_current_after, k)
            s_xb_current = s_xb_current_pool.permute(2, 0, 1).unsqueeze(-1)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current

    def compute_k_style(self, style_batch, k):
        """Return the class code averaged over a batch of k style images."""
        self.eval()
        style_batch = style_batch.cuda()
        s_xb_before = self.gen_test.enc_class_model(style_batch)
        s_xb_after = s_xb_before.squeeze(-1).permute(1, 2, 0)
        s_xb_pool = torch.nn.functional.avg_pool1d(s_xb_after, k)
        s_xb = s_xb_pool.permute(2, 0, 1).unsqueeze(-1)
        return s_xb

    def translate_simple(self, content_image, class_code):
        """Translate a content image given a precomputed class code."""
        self.eval()
        xa = content_image.cuda()
        s_xb_current = class_code.cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current
class SEMIT(nn.Module):
    # FUNIT-style translation model with octave-convolution encoders/decoders
    # (alpha_in/alpha_out control the low/high-frequency channel split), an
    # entropy loss on the content code, and an additional cycle loss.
    def __init__(self, hp):
        super(SEMIT, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        # Snapshot of the generator used by test/translate helpers.
        self.gen_test = copy.deepcopy(self.gen)
        # Pooling/softmax pairs for the entropy loss on the high-frequency
        # (hf) and low-frequency (lf) branches of the content code.
        self.pooling_hf = Avgpool()
        self.logsoftmax_hf = nn.LogSoftmax(dim=1).cuda()
        self.softmax_hf = nn.Softmax(dim=1).cuda()
        self.pooling_lf = Avgpool(kernel_size=2, stride=2)
        self.logsoftmax_lf = nn.LogSoftmax(dim=1).cuda()
        self.softmax_lf = nn.Softmax(dim=1).cuda()

    def forward(self, co_data, cl_data, octave_alpha, hp, mode,
                constant_octave=0.25):
        """Run one generator or discriminator update step.

        Args:
            co_data: (content image, content label) batch.
            cl_data: (class image, class label) batch.
            octave_alpha: octave ratio used by the generator.
            hp: hyperparameter dict with loss weights ('gan_w', 'r_w', 'fm_w').
            mode: 'gen_update' or 'dis_update'.
            constant_octave: fixed octave ratio used by the discriminator.

        Returns:
            A tuple of loss terms and an accuracy; gradients are already
            accumulated via backward() inside this call.
        """
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()
        if mode == 'gen_update':
            c_xa = self.gen.enc_content(xa, alpha_in=octave_alpha,
                                        alpha_out=octave_alpha)
            s_xa = self.gen.enc_class_model(xa, alpha_in=octave_alpha,
                                            alpha_out=octave_alpha)
            s_xb = self.gen.enc_class_model(xb, alpha_in=octave_alpha,
                                            alpha_out=octave_alpha)
            xt = self.gen.decode(c_xa, s_xb, octave_alpha)  # translation
            xr = self.gen.decode(c_xa, s_xa, octave_alpha)  # reconstruction
            l_adv_t, gacc_t, xt_gan_feat = self.dis.calc_gen_loss(
                xt, lb, constant_octave)
            l_adv_r, gacc_r, xr_gan_feat = self.dis.calc_gen_loss(
                xr, la, constant_octave)
            _, xb_gan_feat = self.dis(xb, lb, alpha_in=constant_octave,
                                      alpha_out=constant_octave)
            _, xa_gan_feat = self.dis(xa, la, alpha_in=constant_octave,
                                      alpha_out=constant_octave)
            # Entropy loss on both frequency branches of the content code.
            l_e = entropy_loss(c_xa,
                               (self.pooling_hf, self.pooling_lf),
                               (self.softmax_hf, self.softmax_lf),
                               (self.logsoftmax_hf, self.logsoftmax_lf))
            # Cycle: re-encode the translation and decode back to class a.
            c_xt = self.gen.enc_content(xt, alpha_in=octave_alpha,
                                        alpha_out=octave_alpha)
            xr_cyc = self.gen.decode(c_xt, s_xa, octave_alpha)
            l_x_rec_cyc = recon_criterion(xr_cyc, xa)
            # Feature-matching losses on pooled discriminator features.
            l_c_rec = recon_criterion(xr_gan_feat.mean(3).mean(2),
                                      xa_gan_feat.mean(3).mean(2))
            l_m_rec = recon_criterion(xt_gan_feat.mean(3).mean(2),
                                      xb_gan_feat.mean(3).mean(2))
            l_x_rec = recon_criterion(xr, xa)
            ## rec loss + cycle loss
            l_x_rec = l_x_rec + 1.*l_x_rec_cyc
            l_adv = 0.5 * (l_adv_t + l_adv_r)
            acc = 0.5 * (gacc_t + gacc_r)
            # Entropy loss enters with a fixed 0.01 weight.
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_x_rec + hp[
                'fm_w'] * (l_c_rec + l_m_rec)) + 0.01 * l_e
            l_total.backward()
            return l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc
        elif mode == 'dis_update':
            # Real images need grad for the R1 gradient penalty below.
            xb.requires_grad_()
            # The discriminator always uses the constant octave ratio.
            l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(
                xb, lb, constant_octave)
            l_real = hp['gan_w'] * l_real_pre
            # Graph retained so calc_grad2 can differentiate resp_r w.r.t. xb.
            l_real.backward(retain_graph=True)
            l_reg_pre = self.dis.calc_grad2(resp_r, xb)
            l_reg = 10 * l_reg_pre  # fixed gradient-penalty weight
            l_reg.backward()
            # Generate fakes without building generator gradients.
            with torch.no_grad():
                c_xa = self.gen.enc_content(xa, alpha_in=octave_alpha,
                                            alpha_out=octave_alpha)
                s_xb = self.gen.enc_class_model(xb, alpha_in=octave_alpha,
                                                alpha_out=octave_alpha)
                xt = self.gen.decode(c_xa, s_xb, octave_alpha)
            l_fake_p, acc_f, resp_f = self.dis.calc_dis_fake_loss(
                xt.detach(), lb, constant_octave)
            l_fake = hp['gan_w'] * l_fake_p
            l_fake.backward()
            l_total = l_fake + l_real + l_reg
            acc = 0.5 * (acc_f + acc_r)
            return l_total, l_fake_p, l_real_pre, l_reg_pre, acc
        else:
            assert 0, 'Not support operation'

    def test(self, co_data, cl_data):
        """Translate/reconstruct while sweeping the octave alpha 0.0 .. 1.0.

        Runs both the current and the test generator at 11 alpha values and
        returns the mid-alpha (index 5) reconstructions plus the full sweep
        of test-generator translations.
        """
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        for octave_alpha_value_index in range(11):
            octave_alpha_value = octave_alpha_value_index / 10.
            alpha_in, alpha_out = octave_alpha_value, octave_alpha_value
            c_xa_current = self.gen.enc_content(xa, alpha_in=alpha_in,
                                                alpha_out=alpha_out)
            s_xa_current = self.gen.enc_class_model(xa, alpha_in=alpha_in,
                                                    alpha_out=alpha_out)
            s_xb_current = self.gen.enc_class_model(xb, alpha_in=alpha_in,
                                                    alpha_out=alpha_out)
            xt_current = self.gen.decode(c_xa_current, s_xb_current,
                                         octave_alpha_value)
            xr_current = self.gen.decode(c_xa_current, s_xa_current,
                                         octave_alpha_value)
            c_xa = self.gen_test.enc_content(xa, alpha_in=alpha_in,
                                             alpha_out=alpha_out)
            s_xa = self.gen_test.enc_class_model(xa, alpha_in=alpha_in,
                                                 alpha_out=alpha_out)
            s_xb = self.gen_test.enc_class_model(xb, alpha_in=alpha_in,
                                                 alpha_out=alpha_out)
            xt = self.gen_test.decode(c_xa, s_xb, octave_alpha_value)
            xr = self.gen_test.decode(c_xa, s_xa, octave_alpha_value)
            if octave_alpha_value_index == 0:
                xt_current_set = [xt_current]
                xr_current_set = [xr_current]
                xt_set = [xt]
                xr_set = [xr]
            else:
                xt_current_set.append(xt_current)
                xr_current_set.append(xr_current)
                xt_set.append(xt)
                xr_set.append(xr)
        self.train()
        #return xa, xr_current, xt_current, xb, xr, xt
        return xa, xr_current_set[5], xt_current_set[5], xb, xr_set[5], \
            xt_set[0], xt_set[1], xt_set[2], xt_set[3], xt_set[4], \
            xt_set[5], xt_set[6], xt_set[7], xt_set[8], xt_set[9], xt_set[10]

    def translate_k_shot(self, co_data, cl_data, k):
        """Translate xa into the class of xb using k class (style) images.

        NOTE(review): unlike forward()/test(), the encoders/decoder here are
        called without octave alpha arguments — presumably they have
        defaults; confirm against the FewShotGen definition.
        """
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        if k == 1:
            c_xa_current = self.gen_test.enc_content(xa)
            s_xb_current = self.gen_test.enc_class_model(xb)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        else:
            # Average the k class codes via 1-D average pooling.
            s_xb_current_before = self.gen_test.enc_class_model(xb)
            s_xb_current_after = s_xb_current_before.squeeze(-1).permute(
                1, 2, 0)
            s_xb_current_pool = torch.nn.functional.avg_pool1d(
                s_xb_current_after, k)
            s_xb_current = s_xb_current_pool.permute(2, 0, 1).unsqueeze(-1)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current

    def compute_k_style(self, style_batch, k):
        """Return the class code averaged over a batch of k style images."""
        self.eval()
        style_batch = style_batch.cuda()
        s_xb_before = self.gen_test.enc_class_model(style_batch)
        s_xb_after = s_xb_before.squeeze(-1).permute(1, 2, 0)
        s_xb_pool = torch.nn.functional.avg_pool1d(s_xb_after, k)
        s_xb = s_xb_pool.permute(2, 0, 1).unsqueeze(-1)
        return s_xb

    def translate_simple(self, content_image, class_code):
        """Translate a content image given a precomputed class code."""
        self.eval()
        xa = content_image.cuda()
        s_xb_current = class_code.cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current
class G2GModel(nn.Module):
    # G2G: two-stage translation model. Stage 1 swaps class codes between
    # two images ("mix"); stage 2 re-encodes the mixed images and reassembles
    # the originals, giving long (cycle) reconstruction targets.
    def __init__(self, hp):
        super(G2GModel, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        # Snapshot of the generator used by test/translate helpers.
        self.gen_test = copy.deepcopy(self.gen)

    def cross_forward(self, xa, xb, mode):
        """Full two-stage cross-translation pass.

        Using the meerkat/dog example: let meerkat be xa and dog be xb, then
        - xa_cont / xa_class: content / class code of meerkat
        - xb_cont / xb_class: content / class code of dog
        - mb: translated dog, ma: translated meerkat
        - xa_rec / xb_rec: short (one-stage) reconstructions
        - mb_cont, mb_class, ma_cont, ma_class: codes of the mixed images
        - ra / rb: fully reconstructed meerkat / dog (stage 2)

        Args:
            xa, xb: input image batches.
            mode: 'train' uses self.gen; anything else uses self.gen_test.

        Returns:
            Dict with all intermediate codes and images listed above.
        """
        if mode == 'train':
            gen = self.gen
        else:
            gen = self.gen_test
        # content and class codes image xa (meerkat)
        xa_cont = gen.enc_content(xa)
        xa_class = gen.enc_class_model(xa)
        # content and class codes image xb (dog)
        xb_cont = gen.enc_content(xb)
        xb_class = gen.enc_class_model(xb)
        # mixed images
        mb = gen.decode(xa_cont, xb_class)  # translated dog
        ma = gen.decode(xb_cont, xa_class)  # translated meerkat
        # reconstruction with itself
        xa_rec = gen.decode(xa_cont, xa_class)  # reconstruction of image a
        xb_rec = gen.decode(xb_cont, xb_class)  # reconstruction of image b
        # content and class codes of mixed image mb (translated dog)
        mb_cont = gen.enc_content(mb)
        mb_class = gen.enc_class_model(mb)
        # content and class codes of mixed image ma (translated meerkat)
        ma_cont = gen.enc_content(ma)
        ma_class = gen.enc_class_model(ma)
        # reconstruction after reassembly stage
        ra = gen.decode(mb_cont, ma_class)  # meerkat
        rb = gen.decode(ma_cont, mb_class)  # dog
        return {
            "xa_cont": xa_cont,
            "xa_class": xa_class,
            "xb_cont": xb_cont,
            "xb_class": xb_class,
            "mb": mb,
            "ma": ma,
            "xa_rec": xa_rec,
            "xb_rec": xb_rec,
            "mb_cont": mb_cont,
            "mb_class": mb_class,
            "ma_cont": ma_cont,
            "ma_class": ma_class,
            "ra": ra,
            "rb": rb,
        }

    def calc_g_loss(self, out, xa, xb, la, lb):
        """Compute all generator loss terms from a cross_forward() output.

        Returns (l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec, l_long_fm,
        l_fm_mix_rec, acc); no backward() is called here.
        """
        # adversarial loss, generator accuracy and discriminator features
        # (calc_gen_loss returns loss, accuracy and gan_feat of its first
        # argument only)
        l_adv_mb, gacc_mb, mb_gan_feat = self.dis.calc_gen_loss(
            out["mb"], lb)
        l_adv_ma, gacc_ma, ma_gan_feat = self.dis.calc_gen_loss(
            out["ma"], la)
        l_adv_xa_rec, gacc_xa_rec, xa_rec_gan_feat = self.dis.calc_gen_loss(
            out["xa_rec"], la)
        l_adv_xb_rec, gacc_xb_rec, xb_rec_gan_feat = self.dis.calc_gen_loss(
            out["xb_rec"], lb)
        # extracting features for the feature matching loss
        _, xb_gan_feat = self.dis(xb, lb)
        _, xa_gan_feat = self.dis(xa, la)
        # feature matching loss (short reconstructions vs. originals)
        l_fm_xa_rec = recon_criterion(
            xa_rec_gan_feat.mean(3).mean(2), xa_gan_feat.mean(3).mean(2))
        l_fm_xb_rec = recon_criterion(
            xb_rec_gan_feat.mean(3).mean(2), xb_gan_feat.mean(3).mean(2))
        l_fm_rec = 0.5 * (l_fm_xa_rec + l_fm_xb_rec)
        # feature matching loss (mixed images vs. originals)
        l_fm_mb = recon_criterion(
            mb_gan_feat.mean(3).mean(2), xb_gan_feat.mean(3).mean(2))
        l_fm_ma = recon_criterion(
            ma_gan_feat.mean(3).mean(2), xa_gan_feat.mean(3).mean(2))
        l_fm_m = 0.5 * (l_fm_ma + l_fm_mb)
        # short reconstruction loss
        l_rec_xa = recon_criterion(out["xa_rec"], xa)
        l_rec_xb = recon_criterion(out["xb_rec"], xb)
        l_rec = 0.5 * (l_rec_xa + l_rec_xb)
        # long L1 reconstruction loss (stage-2 reassembly vs. originals)
        l_long_rec_xa = recon_criterion(out["ra"], xa)
        l_long_rec_xb = recon_criterion(out["rb"], xb)
        l_long_rec = 0.5 * (l_long_rec_xa + l_long_rec_xb)
        # long feature matching loss
        _, gacc_long_fm_xa, ra_gan_feat = self.dis.calc_gen_loss(out["ra"],
                                                                 la)
        _, gacc_long_fm_xb, rb_gan_feat = self.dis.calc_gen_loss(out["rb"],
                                                                 lb)
        l_long_fm_xa = recon_criterion(
            ra_gan_feat.mean(3).mean(2), xa_gan_feat.mean(3).mean(2))
        l_long_fm_xb = recon_criterion(
            rb_gan_feat.mean(3).mean(2), xb_gan_feat.mean(3).mean(2))
        l_long_fm = 0.5 * (l_long_fm_xa + l_long_fm_xb)
        # Feature matching loss in second G2G stage: between the mixed image
        # and the reconstructed image of the same class
        l_fm_mix_rec_a = recon_criterion(
            ra_gan_feat.mean(3).mean(2), ma_gan_feat.mean(3).mean(
                2))  # compare reconstructed meerkat with mixed meerkat
        l_fm_mix_rec_b = recon_criterion(
            rb_gan_feat.mean(3).mean(2), mb_gan_feat.mean(3).mean(
                2))  # compare reconstructed dog with mixed dog
        l_fm_mix_rec = 0.5 * (l_fm_mix_rec_a + l_fm_mix_rec_b)
        # adversarial loss averaged over all four generated images
        l_adv = 0.25 * (l_adv_ma + l_adv_mb + l_adv_xa_rec + l_adv_xb_rec)
        # accuracy
        acc = 0.25 * (gacc_ma + gacc_mb + gacc_xa_rec + gacc_xb_rec)
        # overall loss terms: adversarial, reconstruction and feature
        # matching reconstruction, feature matching loss and accuracy
        return l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec, l_long_fm, \
            l_fm_mix_rec, acc

    def calc_d_loss(self, xa, xb, la, lb, gan_weight, reg_weight):
        """Discriminator losses; calls backward() on each weighted term.

        Returns the UN-weighted terms (l_fake_pre, l_real_pre, l_reg_pre)
        plus the mean real/fake accuracy.
        """
        # calculate discriminator's real loss
        l_real_pre_a, acc_r_a, resp_r_a = self.dis.calc_dis_real_loss(xa, la)
        l_real_pre_b, acc_r_b, resp_r_b = self.dis.calc_dis_real_loss(xb, lb)
        l_real_pre = 0.5 * (l_real_pre_a + l_real_pre_b)
        l_real = gan_weight * l_real_pre
        # Graph retained so calc_grad2 can differentiate responses w.r.t.
        # the (requires_grad) real inputs.
        l_real.backward(retain_graph=True)
        # real gradient penalty regularization proposed by Mescheder et al.
        l_reg_pre_a = self.dis.calc_grad2(resp_r_a, xa)
        l_reg_pre_b = self.dis.calc_grad2(resp_r_b, xb)
        l_reg_pre = 0.5 * (l_reg_pre_a + l_reg_pre_b)
        l_reg = reg_weight * l_reg_pre
        l_reg.backward()
        # generate images for the discriminator to classify
        with torch.no_grad():
            xa_cont = self.gen.enc_content(xa)  # meerkat
            xb_cont = self.gen.enc_content(xb)  # dog
            xa_class = self.gen.enc_class_model(xa)
            xb_class = self.gen.enc_class_model(xb)
            mb = self.gen.decode(xa_cont, xb_class)  # dog
            ma = self.gen.decode(xb_cont, xa_class)  # meerkat
        # calculate discriminator's fake loss
        # NOTE(review): label pairing looks swapped — ma is decoded with
        # xa's class code but scored against lb, and mb against la. Verify
        # against calc_dis_fake_loss's expected label semantics.
        l_fake_pre_a, acc_f_a, resp_f_a = self.dis.calc_dis_fake_loss(
            ma.detach(), lb)  # meerkat
        l_fake_pre_b, acc_f_b, resp_f_b = self.dis.calc_dis_fake_loss(
            mb.detach(), la)  # dog
        l_fake_pre = 0.5 * (l_fake_pre_a + l_fake_pre_b)
        l_fake = gan_weight * l_fake_pre
        l_fake.backward()
        acc = 0.25 * (acc_f_a + acc_f_b + acc_r_a + acc_r_b)
        return l_fake_pre, l_real_pre, l_reg_pre, acc

    def forward(self, co_data, cl_data, hp, mode):
        """Run one generator or discriminator update step.

        Params:
            - co_data: content data with one content image and one content
              image label
            - cl_data: class data with one class image and one class image
              label
            - hp: hyperparameters, more specifically weights for the losses
            - mode: discriminator or generator update step

        Returns:
            Generator ('gen_update'):
            - l_total, l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec,
              l_long_fm, l_fm_mix_rec, acc
            Discriminator ('dis_update'):
            - l_total, l_fake, l_real, l_reg, acc (the loss terms are the
              un-weighted values returned by calc_d_loss)
        """
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()
        if mode == 'gen_update':
            # forward pass
            out = self.cross_forward(xa, xb, 'train')
            l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec, l_long_fm, \
                l_fm_mix_rec, acc = self.calc_g_loss(out, xa, xb, la, lb)
            # overall loss: adversarial, reconstruction and feature matching
            # loss
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_rec +
                       hp['fm_rec_w'] * l_fm_rec + hp['fm_w'] * l_fm_m +
                       hp['rl_w'] * l_long_rec + hp['fml_w'] * l_long_fm +
                       hp['fml_mix_rec_w'] * l_fm_mix_rec)
            l_total.backward()
            return l_total, l_adv, l_rec, l_fm_rec, l_fm_m, l_long_rec, \
                l_long_fm, l_fm_mix_rec, acc
        elif mode == 'dis_update':
            # for the gradient penalty regularization
            xa.requires_grad_()
            xb.requires_grad_()
            l_fake, l_real, l_reg, acc = self.calc_d_loss(
                xa, xb, la, lb, hp['gan_w'], hp['reg_w'])
            # overall loss: fake, real and regularization loss term
            # NOTE(review): this reported total re-weights the already
            # pre-weight fake/real terms but adds l_reg un-weighted; the
            # gradients applied in calc_d_loss used reg_w. Reporting only.
            l_total = hp['gan_w'] * (l_fake + l_real) + l_reg
            return l_total, l_fake, l_real, l_reg, acc
        else:
            assert 0, 'Not support operation'

    def test(self, co_data, cl_data):
        """
        Params:
            - co_data: content data with one content image and one content
              image label
            - cl_data: class data with one class image and one class image
              label

        Returns:
            - xa: original image meerkat
            - xb: original image dog
            - mb: mixed image dog in meerkat position
            - ma: mixed image meerkat in dog position
            - ra: reconstructed image meerkat
            - rb: reconstructed image dog
        """
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()  # meerkat
        xb = cl_data[0].cuda()  # dog
        out = self.cross_forward(xa, xb, 'test')
        self.train()
        return xa, xb, out['mb'], out['ma'], out['ra'], out['rb']

    def translate_k_shot(self, co_data, cl_data, k):
        """
        Params:
            - co_data: content data with one content image and one content
              image label
            - cl_data: class data with one class image and one class image
              label
            - k: number of shots to generate a translated image

        Returns:
            - xt_current: translated image at current training state of model
        """
        self.eval()
        # for training on GPU
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        xa_cont_current = self.gen_test.enc_content(xa)
        # perform translation for k shots
        if k == 1:
            xa_cont_current = self.gen_test.enc_content(xa)
            xb_class_current = self.gen_test.enc_class_model(xb)
            xt_current = self.gen_test.decode(xa_cont_current,
                                              xb_class_current)
        else:
            # Average the k class codes via 1-D average pooling.
            xb_class_current_before = self.gen_test.enc_class_model(xb)
            xb_class_current_after = xb_class_current_before.squeeze(
                -1).permute(1, 2, 0)
            xb_class_current_pool = torch.nn.functional.avg_pool1d(
                xb_class_current_after, k)
            xb_class_current = xb_class_current_pool.permute(2, 0,
                                                             1).unsqueeze(-1)
            xt_current = self.gen_test.decode(xa_cont_current,
                                              xb_class_current)
        return xt_current

    def compute_k_style(self, style_batch, k):
        """Return the class code averaged over a batch of k style images."""
        self.eval()
        style_batch = style_batch.cuda()
        xb_class_before = self.gen_test.enc_class_model(style_batch)
        xb_class_after = xb_class_before.squeeze(-1).permute(1, 2, 0)
        xb_class_pool = torch.nn.functional.avg_pool1d(xb_class_after, k)
        xb_class = xb_class_pool.permute(2, 0, 1).unsqueeze(-1)
        return xb_class

    def translate_simple(self, content_image, class_code):
        """Translate a content image given a precomputed class code."""
        self.eval()
        xa = content_image.cuda()
        xb_class_current = class_code.cuda()
        xa_cont_current = self.gen_test.enc_content(xa)
        xt_current = self.gen_test.decode(xa_cont_current, xb_class_current)
        return xt_current

    def translate_cross(self, content_image_a, content_image_b):
        """
        Params:
            - content_image_a
            - content_image_b

        Returns:
            - out: dictionary with all intermediate images and codes
        """
        self.eval()
        xa = content_image_a.cuda()
        xb = content_image_b.cuda()
        # forward pass
        out = self.cross_forward(xa, xb, 'test')
        return out
class FUNITModel(nn.Module):
    # FUNIT few-shot image-to-image translation model: one few-shot
    # generator, one multi-class patch discriminator, and a deep copy of the
    # generator (`gen_test`) used for evaluation/inference.
    def __init__(self, hp):
        super(FUNITModel, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        self.gen_test = copy.deepcopy(self.gen)

    def forward(self, co_data, cl_data, hp, mode):
        """
        Params:
            - co_data: content data with one content image and one content
              image label
            - cl_data: class data with one class image and one class image
              label
            - hp: hyperparameters, more specifically weights for the losses
            - mode: discriminator or generator update step

        Returns:
            Generator:
            - l_total: overall generator loss
            - l_adv: adversarial loss of generator
            - l_x_rec: reconstruction loss
            - l_c_rec: feature matching loss same image
            - l_m_rec: feature matching loss translated image
            - acc: generator accuracy
            Discriminator:
            - l_total: overall discriminator loss
            - l_fake_p: discriminator fake loss
            - l_real_pre: discriminator real loss
            - l_reg_pre: real gradient penalty regularization loss term
            - acc: discriminator accuracy
        """
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()
        if mode == 'gen_update':
            # forward pass
            c_xa = self.gen.enc_content(xa)  # content code of image a
            s_xa = self.gen.enc_class_model(xa)  # class code of image a
            s_xb = self.gen.enc_class_model(xb)  # class code of image b
            xt = self.gen.decode(c_xa, s_xb)  # translated image
            xr = self.gen.decode(c_xa, s_xa)  # reconstructed image a
            # adversarial loss, generator accuracy and features
            # (calc_gen_loss returns loss, accuracy and gan_feat of only its
            # first param, i.e. the generated image)
            l_adv_t, gacc_t, xt_gan_feat = self.dis.calc_gen_loss(xt, lb)
            l_adv_r, gacc_r, xr_gan_feat = self.dis.calc_gen_loss(xr, la)
            # extracting features for the feature matching loss
            _, xb_gan_feat = self.dis(
                xb, lb)  # discriminator features of the original class image
            _, xa_gan_feat = self.dis(xa, la)
            # feature matching loss
            l_c_rec = recon_criterion(
                xr_gan_feat.mean(3).mean(2), xa_gan_feat.mean(3).mean(2))
            l_m_rec = recon_criterion(
                xt_gan_feat.mean(3).mean(2), xb_gan_feat.mean(3).mean(2))
            # short reconstruction loss
            l_x_rec = recon_criterion(xr, xa)
            # adversarial loss
            l_adv = 0.5 * (l_adv_t + l_adv_r)
            # accuracy
            acc = 0.5 * (gacc_t + gacc_r)
            # overall loss: adversarial, reconstruction and feature matching
            # loss
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_x_rec + hp['fm_w']
                       * (l_c_rec + l_m_rec))
            l_total.backward()
            return l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc
        elif mode == 'dis_update':
            # real images need grad for the R1 gradient penalty below
            xb.requires_grad_()
            # calculate discriminator's real loss
            l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb)
            l_real = hp['gan_w'] * l_real_pre
            l_real.backward(retain_graph=True)
            # real gradient penalty regularization proposed by Mescheder et
            # al.
            l_reg_pre = self.dis.calc_grad2(resp_r, xb)
            l_reg = 10 * l_reg_pre
            l_reg.backward()
            # generate images for the discriminator to classify
            with torch.no_grad():
                c_xa = self.gen.enc_content(xa)
                s_xb = self.gen.enc_class_model(xb)
                xt = self.gen.decode(c_xa, s_xb)
            # calculate discriminator's fake loss
            l_fake_p, acc_f, resp_f = self.dis.calc_dis_fake_loss(
                xt.detach(), lb)
            l_fake = hp['gan_w'] * l_fake_p
            l_fake.backward()
            l_total = l_fake + l_real + l_reg
            acc = 0.5 * (acc_f + acc_r)
            return l_total, l_fake_p, l_real_pre, l_reg_pre, acc
        else:
            assert 0, 'Not support operation'

    def test(self, co_data, cl_data):
        """
        Params:
            - co_data: content data with one content image and one content
              image label
            - cl_data: class data with one class image and one class image
              label

        Returns:
            - xa: original image of domain A
            - xr_current: reconstruction by the current generator
            - xt_current: translation by the current generator
            - xb: original image of domain B
            - xr: reconstruction by the test generator
            - xt: translation by the test generator
        """
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen.enc_content(xa)
        s_xa_current = self.gen.enc_class_model(xa)
        s_xb_current = self.gen.enc_class_model(xb)
        xt_current = self.gen.decode(c_xa_current, s_xb_current)
        xr_current = self.gen.decode(c_xa_current, s_xa_current)
        c_xa = self.gen_test.enc_content(xa)
        s_xa = self.gen_test.enc_class_model(xa)
        s_xb = self.gen_test.enc_class_model(xb)
        xt = self.gen_test.decode(c_xa, s_xb)
        xr = self.gen_test.decode(c_xa, s_xa)
        self.train()
        return xa, xr_current, xt_current, xb, xr, xt

    def translate_k_shot(self, co_data, cl_data, k):
        """
        Params:
            - co_data: content data with one content image and one content
              image label
            - cl_data: class data with one class image and one class image
              label
            - k: number of shots to generate a translated image
        """
        self.eval()
        # for training on GPU
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        # perform translation for k shots
        if k == 1:
            c_xa_current = self.gen_test.enc_content(xa)
            s_xb_current = self.gen_test.enc_class_model(xb)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        else:
            # average the k class codes via 1-D average pooling
            s_xb_current_before = self.gen_test.enc_class_model(xb)
            s_xb_current_after = s_xb_current_before.squeeze(-1).permute(
                1, 2, 0)
            s_xb_current_pool = torch.nn.functional.avg_pool1d(
                s_xb_current_after, k)
            s_xb_current = s_xb_current_pool.permute(2, 0, 1).unsqueeze(-1)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current

    def compute_k_style(self, style_batch, k):
        """Return the class code averaged over a batch of k style images."""
        self.eval()
        style_batch = style_batch.cuda()
        s_xb_before = self.gen_test.enc_class_model(style_batch)
        s_xb_after = s_xb_before.squeeze(-1).permute(1, 2, 0)
        s_xb_pool = torch.nn.functional.avg_pool1d(s_xb_after, k)
        s_xb = s_xb_pool.permute(2, 0, 1).unsqueeze(-1)
        return s_xb

    def translate_simple(self, content_image, class_code):
        """Translate a content image given a precomputed class code."""
        self.eval()
        xa = content_image.cuda()
        s_xb_current = class_code.cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current
class FUNITModel(nn.Module):
    # Two-domain FUNIT variant: separate generator/discriminator pairs for
    # domain A (human) and domain B (anime), trained jointly with cross
    # translation plus content/style code-consistency losses.
    #
    # Fixes vs. previous revision:
    #   * translate_k_shot raised NameError for k == 1 (returned xt_current,
    #     which was only assigned in the else branch) and referenced the
    #     nonexistent attribute self.gen_test in the else branch (this class
    #     only defines gen_test_a / gen_test_b).
    #   * dis_b's real loss is now scaled by hp['gan_w'] before backward(),
    #     matching the dis_a branch (previously omitted).
    def __init__(self, hp):
        super(FUNITModel, self).__init__()
        self.gen_a = FewShotGen(hp['gen_a'])  # human domain Generator
        self.gen_b = FewShotGen(hp['gen_b'])  # anime domain Generator
        self.dis_a = GPPatchMcResDis(hp['dis_a'])  # human domain Discriminator
        self.dis_b = GPPatchMcResDis(hp['dis_b'])  # anime domain Discriminator
        # Snapshots used by test/translate helpers.
        self.gen_test_a = copy.deepcopy(self.gen_a)
        self.gen_test_b = copy.deepcopy(self.gen_b)

    def forward(self, co_data, cl_data, hp, mode):
        """Run one generator or discriminator update step.

        Args:
            co_data: (domain-A image, label) batch.
            cl_data: (domain-B image, label) batch.
            hp: hyperparameter dict with loss weights ('gan_w', 'r_w',
                'fm_w', 'c_w', 's_w') and 'mode' ('B' additionally applies
                adversarial/feature/fake-loss terms to reconstructions).
            mode: 'gen_update' or 'dis_update'.

        Returns:
            gen_update: (total_loss, gan_loss, feat_loss, rec_loss,
                         content_loss, style_loss, acc)
            dis_update: (total_loss, fake_loss, real_loss, reg_loss, acc)
        """
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()
        if mode == 'gen_update':
            # content and style codes for both domains
            c_xa = self.gen_a.enc_content(xa)
            c_xb = self.gen_b.enc_content(xb)
            s_xa = self.gen_a.enc_class_model(xa)
            s_xb = self.gen_b.enc_class_model(xb)
            # reconstruction
            xr_a = self.gen_a.decode(c_xa, s_xa)
            xr_b = self.gen_b.decode(c_xb, s_xb)
            # translation
            xt_a2b = self.gen_b.decode(c_xa, s_xb)
            xt_b2a = self.gen_a.decode(c_xb, s_xa)
            # re-encode the translations for code-consistency losses
            c_xt_a2b = self.gen_b.enc_content(xt_a2b)
            c_xt_b2a = self.gen_a.enc_content(xt_b2a)
            s_xt_a2b = self.gen_b.enc_class_model(xt_a2b)
            s_xt_b2a = self.gen_a.enc_class_model(xt_b2a)
            ############ calculate loss ############
            # gan loss
            xt_a2b_gan_loss, xt_a2b_gan_acc, xt_a2b_gan_feat = \
                self.dis_b.calc_gan_loss(xt_a2b, lb)
            xt_b2a_gan_loss, xt_b2a_gan_acc, xt_b2a_gan_feat = \
                self.dis_a.calc_gan_loss(xt_b2a, la)
            xr_a_gan_loss, xr_a_gan_acc, xr_a_gan_feat = \
                self.dis_a.calc_gan_loss(xr_a, la)
            xr_b_gan_loss, xr_b_gan_acc, xr_b_gan_feat = \
                self.dis_b.calc_gan_loss(xr_b, lb)
            if hp['mode'] == 'B':
                gan_loss = (xt_a2b_gan_loss + xt_b2a_gan_loss +
                            xr_a_gan_loss + xr_b_gan_loss) * 0.5
            else:
                gan_loss = xt_a2b_gan_loss + xt_b2a_gan_loss
            # feature matching loss on pooled discriminator features
            _, xb_gan_feat = self.dis_b(xb, lb)
            _, xa_gan_feat = self.dis_a(xa, la)
            xr_feat_loss = recon_criterion(xr_a_gan_feat.mean(3).mean(2),
                                           xa_gan_feat.mean(3).mean(2)) + \
                recon_criterion(xr_b_gan_feat.mean(3).mean(2),
                                xb_gan_feat.mean(3).mean(2))
            xt_feat_loss = recon_criterion(xt_b2a_gan_feat.mean(3).mean(2),
                                           xa_gan_feat.mean(3).mean(2)) + \
                recon_criterion(xt_a2b_gan_feat.mean(3).mean(2),
                                xb_gan_feat.mean(3).mean(2))
            if hp['mode'] == 'B':
                feat_loss = xt_feat_loss + xr_feat_loss
            else:
                feat_loss = xt_feat_loss
            # reconstruction loss
            xa_rec_loss = recon_criterion(xr_a, xa)
            xb_rec_loss = recon_criterion(xr_b, xb)
            rec_loss = (xa_rec_loss + xb_rec_loss)
            # content loss: translation must preserve the content code
            content_a2b_loss = recon_criterion(c_xa, c_xt_a2b)
            content_b2a_loss = recon_criterion(c_xb, c_xt_b2a)
            content_loss = (content_a2b_loss + content_b2a_loss)
            # style loss: translation must carry the target style code
            style_a2b_loss = recon_criterion(s_xb, s_xt_a2b)
            style_b2a_loss = recon_criterion(s_xa, s_xt_b2a)
            style_loss = (style_a2b_loss + style_b2a_loss)
            # total loss
            total_loss = hp['gan_w'] * gan_loss + hp['r_w'] * rec_loss + hp[
                'fm_w'] * feat_loss + hp['c_w'] * content_loss + hp[
                's_w'] * style_loss
            total_loss.backward()
            acc = 0.5 * (xt_a2b_gan_acc + xt_b2a_gan_acc
                         )  # the accuracy of fake image recognition
            return total_loss, gan_loss, feat_loss, rec_loss, \
                content_loss, style_loss, acc
        elif mode == 'dis_update':
            # real images need grad for the R1 gradient penalty below
            xb.requires_grad_()
            xa.requires_grad_()
            ################# dis_a #################
            dis_a_real_loss, dis_a_real_acc, dis_a_real_resp = \
                self.dis_a.calc_dis_real_loss(xa, la)  # real loss
            dis_a_real_loss = hp['gan_w'] * dis_a_real_loss
            dis_a_real_loss.backward(retain_graph=True)
            dis_a_reg_loss = 10 * self.dis_a.calc_grad2(dis_a_real_resp,
                                                        xa)  # reg loss
            dis_a_reg_loss.backward()
            # fake loss (no generator gradients)
            with torch.no_grad():
                c_xb = self.gen_b.enc_content(xb)
                c_xa = self.gen_a.enc_content(xa)
                s_xa = self.gen_a.enc_class_model(xa)
                xr_a = self.gen_a.decode(c_xa, s_xa)
                xt_b2a = self.gen_a.decode(c_xb, s_xa)
            dis_at_fake_loss, dis_at_fake_acc, dis_at_fake_resp = \
                self.dis_a.calc_dis_fake_loss(xt_b2a.detach(), la)
            dis_ar_fake_loss, dis_ar_fake_acc, dis_ar_fake_resp = \
                self.dis_a.calc_dis_fake_loss(xr_a.detach(), la)
            if hp['mode'] == 'B':
                dis_a_fake_loss = hp['gan_w'] * (dis_at_fake_loss +
                                                 dis_ar_fake_loss) * 0.5
            else:
                dis_a_fake_loss = hp['gan_w'] * dis_at_fake_loss
            dis_a_fake_loss.backward()
            ################# dis_b #################
            dis_b_real_loss, dis_b_real_acc, dis_b_real_resp = \
                self.dis_b.calc_dis_real_loss(xb, lb)  # real loss
            # Fix: scale by gan_w to mirror the dis_a branch (was missing).
            dis_b_real_loss = hp['gan_w'] * dis_b_real_loss
            dis_b_real_loss.backward(retain_graph=True)
            dis_b_reg_loss = 10 * self.dis_b.calc_grad2(dis_b_real_resp,
                                                        xb)  # reg loss
            dis_b_reg_loss.backward()
            # fake loss (no generator gradients)
            with torch.no_grad():
                c_xa = self.gen_a.enc_content(xa)
                c_xb = self.gen_b.enc_content(xb)
                s_xb = self.gen_b.enc_class_model(xb)
                xr_b = self.gen_b.decode(c_xb, s_xb)
                xt_a2b = self.gen_b.decode(c_xa, s_xb)
            dis_bt_fake_loss, dis_bt_fake_acc, dis_bt_fake_resp = \
                self.dis_b.calc_dis_fake_loss(xt_a2b.detach(), lb)
            dis_br_fake_loss, dis_br_fake_acc, dis_br_fake_resp = \
                self.dis_b.calc_dis_fake_loss(xr_b.detach(), lb)
            if hp['mode'] == 'B':
                dis_b_fake_loss = hp['gan_w'] * (dis_bt_fake_loss +
                                                 dis_br_fake_loss) * 0.5
            else:
                dis_b_fake_loss = hp['gan_w'] * dis_bt_fake_loss
            dis_b_fake_loss.backward()
            real_loss = (dis_a_real_loss + dis_b_real_loss)
            fake_loss = (dis_a_fake_loss + dis_b_fake_loss)
            reg_loss = (dis_a_reg_loss + dis_b_reg_loss)
            total_loss = (dis_a_fake_loss + dis_b_fake_loss +
                          dis_a_real_loss + dis_b_real_loss +
                          dis_a_reg_loss + dis_b_reg_loss)
            acc = 0.25 * (dis_at_fake_acc + dis_bt_fake_acc +
                          dis_a_real_acc + dis_b_real_acc)
            return total_loss, fake_loss, real_loss, reg_loss, acc
        else:
            assert 0, 'Not support operation'

    def test(self, co_data, cl_data):
        """Reconstruct and cross-translate with the test generators."""
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa = self.gen_test_a.enc_content(xa)
        c_xb = self.gen_test_b.enc_content(xb)
        s_xa = self.gen_test_a.enc_class_model(xa)
        s_xb = self.gen_test_b.enc_class_model(xb)
        xr_a = self.gen_test_a.decode(c_xa, s_xa)
        xr_b = self.gen_test_b.decode(c_xb, s_xb)
        xt_a2b = self.gen_test_b.decode(c_xa, s_xb)
        xt_b2a = self.gen_test_a.decode(c_xb, s_xa)
        self.train()
        return xa, xr_a, xt_b2a, xb, xr_b, xt_a2b

    def translate_k_shot(self, co_data, cl_data, k):
        """Translate a domain-A image to domain B using k style images.

        Bug fix: the k == 1 branch previously left `xt_current` unbound
        (NameError on return) and the k > 1 branch decoded through the
        nonexistent `self.gen_test` attribute.
        """
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa = self.gen_test_a.enc_content(xa)
        if k == 1:
            s_xb = self.gen_test_b.enc_class_model(xb)
        else:
            # Average the k class codes via 1-D average pooling.
            s_xb = self.gen_test_b.enc_class_model(xb)
            s_xb = s_xb.squeeze(-1).permute(1, 2, 0)
            s_xb = torch.nn.functional.avg_pool1d(s_xb, k)
            s_xb = s_xb.permute(2, 0, 1).unsqueeze(-1)
        xt_a2b = self.gen_test_b.decode(c_xa, s_xb)
        return xt_a2b

    def compute_k_style(self, style_batch, k):
        """Return the class code averaged over a batch of k style images."""
        self.eval()
        style_batch = style_batch.cuda()
        s_xb_before = self.gen_test_b.enc_class_model(style_batch)
        s_xb_after = s_xb_before.squeeze(-1).permute(1, 2, 0)
        s_xb_pool = torch.nn.functional.avg_pool1d(s_xb_after, k)
        s_xb = s_xb_pool.permute(2, 0, 1).unsqueeze(-1)
        return s_xb

    def translate_simple(self, content_image, class_code):
        """Translate a domain-A image given a precomputed class code."""
        self.eval()
        xa = content_image.cuda()
        s_xb = class_code.cuda()
        c_xa = self.gen_test_a.enc_content(xa)
        xt = self.gen_test_b.decode(c_xa, s_xb)
        return xt
class FUNITModel(nn.Module):
    """FUNIT few-shot translation model.

    Holds one generator, one multi-class patch discriminator, and a deep
    copy of the generator (``gen_test``) that is updated outside this class
    and used for all evaluation-time translation.
    """

    def __init__(self, hp):
        super(FUNITModel, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        # Evaluation copy of the generator (updated externally, e.g. via EMA).
        self.gen_test = copy.deepcopy(self.gen)

    def _backward(self, loss, retain_graph=False):
        """Backpropagate ``loss``, routing through apex AMP loss scaling when
        enabled. Requires :meth:`setOptimizersForApex` to have been called
        when ``GlobalConstants.usingApex`` is true."""
        if GlobalConstants.usingApex:
            with amp.scale_loss(
                    loss, [self.gen_opt, self.dis_opt]) as scaled_loss:
                scaled_loss.backward(retain_graph=retain_graph)
        else:
            loss.backward(retain_graph=retain_graph)

    def forward(self, co_data, cl_data, hp, mode):
        """Run one training step.

        co_data/cl_data: (image, label) pairs for content and class batches.
        hp: hyper-parameter dict with 'gan_w', 'r_w', 'fm_w' weights.
        mode: 'gen_update' or 'dis_update'; anything else asserts.
        Gradients are accumulated here; the caller performs the optimizer step.
        """
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()
        if mode == 'gen_update':
            c_xa = self.gen.enc_content(xa)
            s_xa = self.gen.enc_class_model(xa)
            s_xb = self.gen.enc_class_model(xb)
            xt = self.gen.decode(c_xa, s_xb)  # translation
            xr = self.gen.decode(c_xa, s_xa)  # reconstruction
            l_adv_t, gacc_t, xt_gan_feat = self.dis.calc_gen_loss(xt, lb)
            l_adv_r, gacc_r, xr_gan_feat = self.dis.calc_gen_loss(xr, la)
            _, xb_gan_feat = self.dis(xb, lb)
            _, xa_gan_feat = self.dis(xa, la)
            # Feature-matching losses on spatially averaged discriminator features.
            l_c_rec = recon_criterion(xr_gan_feat.mean(3).mean(2),
                                      xa_gan_feat.mean(3).mean(2))
            l_m_rec = recon_criterion(xt_gan_feat.mean(3).mean(2),
                                      xb_gan_feat.mean(3).mean(2))
            # .float() guards against a half-precision input under AMP.
            l_x_rec = recon_criterion(xr, xa.float())
            l_adv = 0.5 * (l_adv_t + l_adv_r)
            acc = 0.5 * (gacc_t + gacc_r)
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_x_rec
                       + hp['fm_w'] * (l_c_rec + l_m_rec))
            self._backward(l_total)
            return l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc
        elif mode == 'dis_update':
            xb.requires_grad_()  # needed for the gradient penalty on real inputs
            l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb)
            l_real = hp['gan_w'] * l_real_pre
            # retain_graph: calc_grad2 below differentiates resp_r w.r.t. xb.
            self._backward(l_real, retain_graph=True)
            l_reg_pre = self.dis.calc_grad2(resp_r, xb)
            l_reg = 10 * l_reg_pre
            self._backward(l_reg)
            # Generate fakes without building a generator graph.
            with torch.no_grad():
                c_xa = self.gen.enc_content(xa)
                s_xb = self.gen.enc_class_model(xb)
                xt = self.gen.decode(c_xa, s_xb)
            l_fake_p, acc_f, resp_f = self.dis.calc_dis_fake_loss(
                xt.detach(), lb)
            l_fake = hp['gan_w'] * l_fake_p
            self._backward(l_fake)
            l_total = l_fake + l_real + l_reg
            acc = 0.5 * (acc_f + acc_r)
            return l_total, l_fake_p, l_real_pre, l_reg_pre, acc
        else:
            assert 0, 'Not support operation'

    def test(self, co_data, cl_data):
        """Produce evaluation images with both the current and the test
        generator; returns (xa, xr_current, xt_current, xb, xr, xt)."""
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        # Current generator: translation and reconstruction.
        c_xa_current = self.gen.enc_content(xa)
        s_xa_current = self.gen.enc_class_model(xa)
        s_xb_current = self.gen.enc_class_model(xb)
        xt_current = self.gen.decode(c_xa_current, s_xb_current)
        xr_current = self.gen.decode(c_xa_current, s_xa_current)
        # Test (EMA) generator: same pair of outputs.
        c_xa = self.gen_test.enc_content(xa)
        s_xa = self.gen_test.enc_class_model(xa)
        s_xb = self.gen_test.enc_class_model(xb)
        xt = self.gen_test.decode(c_xa, s_xb)
        xr = self.gen_test.decode(c_xa, s_xa)
        self.train()
        return xa, xr_current, xt_current, xb, xr, xt

    def translate_k_shot(self, co_data, cl_data, k):
        """Translate content images using a class code averaged over k class
        images. Fix: the k == 1 branch re-encoded the content a second time
        for no effect; the redundant call is removed.
        """
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        if k == 1:
            s_xb_current = self.gen_test.enc_class_model(xb)
        else:
            # Average the k class codes along the batch axis.
            # NOTE(review): assumes codes are (k, C, 1, 1) — confirm upstream.
            s_xb_current_before = self.gen_test.enc_class_model(xb)
            s_xb_current_after = s_xb_current_before.squeeze(-1).permute(
                1, 2, 0)
            s_xb_current_pool = torch.nn.functional.avg_pool1d(
                s_xb_current_after, k)
            s_xb_current = s_xb_current_pool.permute(2, 0, 1).unsqueeze(-1)
        xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current

    def compute_k_style(self, style_batch, k):
        """Return the class code of ``style_batch`` averaged over its k images."""
        self.eval()
        style_batch = style_batch.cuda()
        s_xb_before = self.gen_test.enc_class_model(style_batch)
        s_xb_after = s_xb_before.squeeze(-1).permute(1, 2, 0)
        s_xb_pool = torch.nn.functional.avg_pool1d(s_xb_after, k)
        s_xb = s_xb_pool.permute(2, 0, 1).unsqueeze(-1)
        return s_xb

    def translate_simple(self, content_image, class_code):
        """Decode ``content_image`` under a precomputed class code
        (e.g. from :meth:`compute_k_style`)."""
        self.eval()
        xa = content_image.cuda()
        s_xb_current = class_code.cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current

    def setOptimizersForApex(self, gen_opt, dis_opt):
        """Register the optimizers apex AMP needs for loss scaling."""
        self.gen_opt = gen_opt
        self.dis_opt = dis_opt