def forward(self, input_img=None, input_text=None):
    results = dict()
    latents = self.inference(input_img=input_img, input_text=input_text)
    results['latents'] = latents
    mus = latents['mus']
    logvars = latents['logvars']
    weights = latents['weights']
    # the joint divergence is only defined when both modalities are present
    if input_img is not None and input_text is not None:
        div = self.calc_joint_divergence(mus, logvars, weights)
        for k, key in enumerate(div.keys()):
            results[key] = div[key]
    results['group_distr'] = latents['joint']
    class_embeddings = utils.reparameterize(latents['joint'][0],
                                            latents['joint'][1])
    if self.flags.factorized_representation:
        if input_img is not None:
            m1_s_mu, m1_s_logvar = latents['img_celeba'][:2]
            m1_style_latent_embeddings = utils.reparameterize(
                mu=m1_s_mu, logvar=m1_s_logvar)
        if input_text is not None:
            m2_s_mu, m2_s_logvar = latents['text'][:2]
            m2_style_latent_embeddings = utils.reparameterize(
                mu=m2_s_mu, logvar=m2_s_logvar)
    else:
        m1_style_latent_embeddings = None
        m2_style_latent_embeddings = None
    # reconstruct only the modalities that were actually observed
    m1_rec = None
    m2_rec = None
    if input_img is not None:
        m1_rec = self.lhood_celeba(
            *self.decoder_img(m1_style_latent_embeddings, class_embeddings))
    if input_text is not None:
        m2_rec = self.lhood_text(
            *self.decoder_text(m2_style_latent_embeddings, class_embeddings))
    results['rec'] = {'img_celeba': m1_rec, 'text': m2_rec}
    return results
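# Both the joint and the modality-specific codes above are sampled with
# utils.reparameterize, which is defined elsewhere in the repository. A minimal
# sketch of the standard reparameterization trick it is assumed to implement
# (z = mu + sigma * eps with eps ~ N(0, I), so the sample stays differentiable
# w.r.t. mu and logvar):
#
#     def reparameterize(mu, logvar):
#         std = torch.exp(0.5 * logvar)   # logvar parameterizes log(sigma^2)
#         eps = torch.randn_like(std)     # noise drawn outside the autograd graph
#         return mu + eps * std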
def cond_generation(self, latent_distributions):
    if 'img_celeba' in latent_distributions:
        [m1_mu, m1_logvar] = latent_distributions['img_celeba']
        content_cond_m1 = utils.reparameterize(mu=m1_mu, logvar=m1_logvar)
        num_samples = m1_mu.shape[0]
    if 'text' in latent_distributions:
        [m2_mu, m2_logvar] = latent_distributions['text']
        content_cond_m2 = utils.reparameterize(mu=m2_mu, logvar=m2_logvar)
        num_samples = m2_mu.shape[0]
    if self.flags.factorized_representation:
        random_style_m1 = torch.randn(num_samples, self.flags.style_m1_dim)
        random_style_m2 = torch.randn(num_samples, self.flags.style_m2_dim)
        random_style_m1 = random_style_m1.to(self.flags.device)
        random_style_m2 = random_style_m2.to(self.flags.device)
    else:
        random_style_m1 = None
        random_style_m2 = None
    style_latents = {'img_celeba': random_style_m1, 'text': random_style_m2}
    cond_gen_samples = dict()
    if 'img_celeba' in latent_distributions:
        latents_img = {'content': content_cond_m1, 'style': style_latents}
        cond_gen_samples['img_celeba'] = self.generate_from_latents(latents_img)
    if 'text' in latent_distributions:
        latents_text = {'content': content_cond_m2, 'style': style_latents}
        cond_gen_samples['text'] = self.generate_from_latents(latents_text)
    return cond_gen_samples
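# cond_generation draws the style codes from a standard normal inline; the
# methods below delegate the same step to self.get_random_styles, which is not
# shown here. A standalone sketch of what that helper is assumed to do,
# mirroring the sampling above (keys and flag names follow this file):
#
#     def get_random_styles(flags, num_samples):
#         if flags.factorized_representation:
#             z_style_1 = torch.randn(num_samples, flags.style_m1_dim).to(flags.device)
#             z_style_2 = torch.randn(num_samples, flags.style_m2_dim).to(flags.device)
#         else:
#             z_style_1 = None
#             z_style_2 = None
#         return {'img_celeba': z_style_1, 'text': z_style_2}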
def cond_generation_2a(self, latent_distribution_pairs, num_samples=None):
    if num_samples is None:
        num_samples = self.flags.batch_size

    mu0 = torch.zeros(1, num_samples, self.flags.class_dim)
    logvar0 = torch.zeros(1, num_samples, self.flags.class_dim)
    mu0 = mu0.to(self.flags.device)
    logvar0 = logvar0.to(self.flags.device)
    style_latents = self.get_random_styles(num_samples)
    cond_gen_2a = dict()
    for p, pair in enumerate(latent_distribution_pairs.keys()):
        ld_pair = latent_distribution_pairs[pair]
        mu_list = [mu0]
        logvar_list = [logvar0]
        for k, key in enumerate(ld_pair['latents'].keys()):
            mu_list.append(ld_pair['latents'][key][0].unsqueeze(0))
            logvar_list.append(ld_pair['latents'][key][1].unsqueeze(0))
        mus = torch.cat(mu_list, dim=0)
        logvars = torch.cat(logvar_list, dim=0)
        weights_pair = ld_pair['weights']
        weights_pair.insert(0, self.weights[0])
        weights_pair = utils.reweight_weights(torch.Tensor(weights_pair))
        mu_joint, logvar_joint = self.modality_fusion(mus, logvars, weights_pair)
        # mu_joint, logvar_joint = poe(mus, logvars)
        c_emb = utils.reparameterize(mu_joint, logvar_joint)
        l_2a = {'content': c_emb, 'style': style_latents}
        cond_gen_2a[pair] = self.generate_from_latents(l_2a)
    return cond_gen_2a
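# modality_fusion combines the prior expert (mu0, logvar0) with the selected
# unimodal posteriors using the given mixture weights; the commented-out call
# hints at a product-of-experts alternative. A minimal PoE sketch under the
# usual diagonal-Gaussian-expert assumption (precision-weighted fusion), not
# necessarily the fusion rule this model actually uses:
#
#     def poe(mus, logvars, eps=1e-8):
#         # mus, logvars: [num_experts, num_samples, class_dim]
#         precisions = 1.0 / (torch.exp(logvars) + eps)          # 1 / sigma_k^2
#         joint_var = 1.0 / precisions.sum(dim=0)
#         joint_mu = joint_var * (mus * precisions).sum(dim=0)   # precision-weighted mean
#         return joint_mu, torch.log(joint_var + eps)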
def get_latent_samples(flags, latents, mod_names):
    l_c = latents['content']
    l_s = latents['style']
    c_emb = utils.reparameterize(l_c[0], l_c[1])
    styles = dict()
    c = {'mu': l_c[0], 'logvar': l_c[1], 'z': c_emb}
    if flags.factorized_representation:
        for k, key in enumerate(l_s.keys()):
            s_emb = utils.reparameterize(l_s[key][0], l_s[key][1])
            s = {'mu': l_s[key][0], 'logvar': l_s[key][1], 'z': s_emb}
            styles[key] = s
    else:
        for k, key in enumerate(mod_names):
            styles[key] = None
    emb = {'content': c, 'style': styles}
    return emb
def forward(self, input_mnist=None, input_svhn=None, input_text=None):
    latents = self.inference(input_mnist, input_svhn, input_text)
    results = dict()
    results['latents'] = latents
    results['group_distr'] = latents['joint']
    class_embeddings = utils.reparameterize(latents['joint'][0],
                                            latents['joint'][1])
    div = self.calc_joint_divergence(latents['mus'], latents['logvars'],
                                     latents['weights'])
    for k, key in enumerate(div.keys()):
        results[key] = div[key]

    results_rec = dict()
    if input_mnist is not None:
        m1_s_mu, m1_s_logvar = latents['img_mnist'][:2]
        if self.flags.factorized_representation:
            m1_s_embeddings = utils.reparameterize(mu=m1_s_mu, logvar=m1_s_logvar)
        else:
            m1_s_embeddings = None
        m1_rec = self.lhood_mnist(
            *self.decoder_mnist(m1_s_embeddings, class_embeddings))
        results_rec['img_mnist'] = m1_rec
    if input_svhn is not None:
        m2_s_mu, m2_s_logvar = latents['img_svhn'][:2]
        if self.flags.factorized_representation:
            m2_s_embeddings = utils.reparameterize(mu=m2_s_mu, logvar=m2_s_logvar)
        else:
            m2_s_embeddings = None
        m2_rec = self.lhood_svhn(
            *self.decoder_svhn(m2_s_embeddings, class_embeddings))
        results_rec['img_svhn'] = m2_rec
    if input_text is not None:
        m3_s_mu, m3_s_logvar = latents['text'][:2]
        if self.flags.factorized_representation:
            m3_s_embeddings = utils.reparameterize(mu=m3_s_mu, logvar=m3_s_logvar)
        else:
            m3_s_embeddings = None
        m3_rec = self.lhood_text(
            *self.decoder_text(m3_s_embeddings, class_embeddings))
        results_rec['text'] = m3_rec
    results['rec'] = results_rec
    return results
def cond_generation_1a(self, latent_distributions, num_samples=None):
    if num_samples is None:
        num_samples = self.flags.batch_size

    style_latents = self.get_random_styles(num_samples)
    cond_gen_samples = dict()
    for k, key in enumerate(latent_distributions):
        [mu, logvar] = latent_distributions[key]
        content_rep = utils.reparameterize(mu=mu, logvar=logvar)
        latents = {'content': content_rep, 'style': style_latents}
        cond_gen_samples[key] = self.generate_from_latents(latents)
    return cond_gen_samples
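# Hypothetical call pattern for cond_generation_1a (the batch variable names
# are illustrative; the latent keys and the [2:4] slice for the class
# mu/logvar follow the conventions used elsewhere in this repository):
#
#     latents = model.inference(input_mnist=batch_mnist)
#     cond = model.cond_generation_1a({'img_mnist': latents['img_mnist'][2:4]})
#     svhn_from_mnist = cond['img_mnist']['img_svhn']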
def main():
    global args
    cfg = parseArgs()
    if not os.path.exists(cfg.MISC.OUTPUT_PATH):
        os.makedirs(cfg.MISC.OUTPUT_PATH)

    # models
    encoderVis, encoderNir, netG = defineG(hdim=cfg.G.TRAIN.HDIM)
    netIP = defineIP(isTrain=False)
    print('==> Loading pre-trained identity preserving model from {}'.format(
        cfg.G.NET_IP))
    checkpoint = torch.load(cfg.G.NET_IP)
    pretrainedDict = checkpoint['state_dict']
    modelDict = netIP.state_dict()
    pretrainedDict = {k: v for k, v in pretrainedDict.items() if k in modelDict}
    modelDict.update(pretrainedDict)
    netIP.load_state_dict(modelDict)
    # the identity-preserving network stays frozen during training
    for param in netIP.parameters():
        param.requires_grad = False

    # optimizer
    optimizer = torch.optim.Adam(list(netG.parameters()) +
                                 list(encoderVis.parameters()) +
                                 list(encoderNir.parameters()),
                                 lr=cfg.G.TRAIN.LR)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=cfg.G.TRAIN.MILESTONE, gamma=0.1, last_epoch=-1)

    # resume
    if cfg.G.TRAIN.RESUME:
        encoderVis, encoderNir, netG, startEpoch = loadModel(
            cfg, encoderNir, encoderVis, netG)
        optimizer = loadOptimizer(cfg, optimizer)
    else:
        startEpoch = 0

    # criterion
    l2Loss = torch.nn.MSELoss()
    l1Loss = torch.nn.L1Loss()
    smoothL1Loss = torch.nn.SmoothL1Loss()
    lossDict = {'l1': l1Loss, 'l2': l2Loss, 'smoothL1': smoothL1Loss}
    ipLoss = lossDict[cfg.G.TRAIN.IP_LOSS].cuda()
    pairLoss = lossDict[cfg.G.TRAIN.PAIR_LOSS].cuda()
    recLoss = lossDict[cfg.G.TRAIN.REC_LOSS].cuda()

    # dataloader
    trainLoader = torch.utils.data.DataLoader(
        GenDataset(imgRoot=cfg.G.DATASET.ROOT,
                   protocolsRoot=cfg.G.DATASET.PROTOCOLS),
        batch_size=cfg.G.TRAIN.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.G.TRAIN.NUM_WORKERS,
        pin_memory=True,
        drop_last=False)

    # writer
    TIMESTAMP = "{0:%Y%m%dT%H%M%S}".format(datetime.now())
    writer = SummaryWriter(logdir=os.path.join(cfg.MISC.OUTPUT_PATH, 'run',
                                               '{}'.format(cfg.CFG_NAME)))

    for epoch in range(startEpoch, cfg.G.TRAIN.EPOCH):
        batchTime = AverageMeter()
        dataTime = AverageMeter()
        losses = AverageMeter()
        recLosses = AverageMeter()
        klLosses = AverageMeter()
        mmdLosses = AverageMeter()
        ipLosses = AverageMeter()
        pairLosses = AverageMeter()

        encoderVis.train()
        encoderNir.train()
        netG.train()
        netIP.eval()

        startTime = time.time()
        for i, batch in enumerate(trainLoader):
            dataTime.update(time.time() - startTime)
            imgNir = Variable(batch['0'].cuda())
            imgVis = Variable(batch['1'].cuda())
            img = torch.cat((imgNir, imgVis), dim=1)

            # encoder forward
            muNir, logvarNir = encoderNir(imgNir)
            muVis, logvarVis = encoderVis(imgVis)

            # re-parametrization
            zNir = reparameterize(muNir, logvarNir)
            zVis = reparameterize(muVis, logvarVis)

            # generator
            rec = netG(torch.cat((zNir, zVis), dim=1))

            # vae loss
            # lossRec = reconLoss(rec, img, True) / 2.
            lossRec = cfg.G.TRAIN.LAMBDA_REC * recLoss(rec, img) / 2.0
            lossKL = cfg.G.TRAIN.LAMBDA_KL * (
                klLoss(muNir, logvarNir).mean() +
                klLoss(muVis, logvarVis).mean()) / 2.0

            # mmd loss
            lossMMD = cfg.G.TRAIN.LAMBDA_MMD * torch.abs(
                zNir.mean(dim=0) - zVis.mean(dim=0)).mean()

            # identity preserving loss
            recNir = rec[:, 0:3, :, :]
            recVis = rec[:, 3:6, :, :]
            embedNir = F.normalize(netIP(rgb2gray(imgNir))[0], p=2, dim=1)
            embedVis = F.normalize(netIP(rgb2gray(imgVis))[0], p=2, dim=1)
            recEmbedNir = F.normalize(netIP(rgb2gray(recNir))[0], p=2, dim=1)
            recEmbedVis = F.normalize(netIP(rgb2gray(recVis))[0], p=2, dim=1)
            lossIP = cfg.G.TRAIN.LAMBDA_IP * (
                ipLoss(recEmbedNir, embedNir.detach()) +
                ipLoss(recEmbedVis, embedVis.detach())) / 2.0
            lossPair = cfg.G.TRAIN.LAMBDA_PAIR * pairLoss(recEmbedNir, recEmbedVis)

            # warm up: down-weight all regularizers for the first two epochs
            if epoch < 2:
                loss = lossRec + 0.01 * lossKL + 0.01 * lossMMD \
                    + 0.01 * lossIP + 0.01 * lossPair
            else:
                loss = lossRec + lossKL + lossMMD + lossIP + lossPair

            losses.update(loss.item())
            recLosses.update(lossRec.item())
            klLosses.update(lossKL.item())
            mmdLosses.update(lossMMD.item())
            ipLosses.update(lossIP.item())
            pairLosses.update(lossPair.item())

            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batchTime.update(time.time() - startTime)
            startTime = time.time()
            scheduler.step(epoch)

            if i % cfg.G.TRAIN.PRINT_FREQ == 0:
                info = '==> Epoch: [{:0>4d}][{:3d}/{:3d}] Batch time: {:4.3f} Data time: {:4.3f} | '.format(
                    epoch, i, len(trainLoader), batchTime.avg, dataTime.avg)
                info += 'Loss: rec: {:4.3f} kl: {:4.3f} mmd: {:4.3f} ip: {:4.8f} pair: {:4.8f}'.format(
                    lossRec.item(), lossKL.item(), lossMMD.item(),
                    lossIP.item(), lossPair.item())
                print(info)

        # writer
        writer.add_scalar('loss/loss', losses.avg, epoch)
        writer.add_scalar('loss/recLoss', recLosses.avg, epoch)
        writer.add_scalar('loss/klLoss', klLosses.avg, epoch)
        writer.add_scalar('loss/mmdLoss', mmdLosses.avg, epoch)
        writer.add_scalar('loss/ipLoss', ipLosses.avg, epoch)
        writer.add_scalar('loss/pairLoss', pairLosses.avg, epoch)

        x = vutils.make_grid(imgNir.data, normalize=True, scale_each=True)
        writer.add_image('nir/imgNir', x, epoch)
        x = vutils.make_grid(imgVis.data, normalize=True, scale_each=True)
        writer.add_image('vis/imgVis', x, epoch)
        x = vutils.make_grid(recNir.data, normalize=True, scale_each=True)
        writer.add_image('nir/recNIR', x, epoch)
        x = vutils.make_grid(recVis.data, normalize=True, scale_each=True)
        writer.add_image('vis/recVis', x, epoch)

        # sample paired NIR/VIS images from shared noise
        noise = torch.zeros(cfg.G.TRAIN.BATCH_SIZE, cfg.G.TRAIN.HDIM).normal_(0, 1)
        noise = torch.cat((noise, noise), dim=1)
        noise = noise.cuda()
        fakeImg = netG(noise)
        x = vutils.make_grid(fakeImg[:, 0:3, :, :].data, normalize=True,
                             scale_each=True)
        writer.add_image('fake/fakeNir', x, epoch)
        x = vutils.make_grid(fakeImg[:, 3:6, :, :].data, normalize=True,
                             scale_each=True)
        writer.add_image('fake/fakeVis', x, epoch)

        # evaluation
        if not os.path.isdir(cfg.G.TEST.IMG_DUMP):
            os.makedirs(cfg.G.TEST.IMG_DUMP)
        if (epoch + 1) % cfg.G.TEST.FREQ == 0:
            noise = torch.zeros(cfg.G.TRAIN.BATCH_SIZE,
                                cfg.G.TRAIN.HDIM).normal_(0, 1)
            noise = torch.cat((noise, noise), dim=1)
            noise = noise.cuda()
            fakeImg = netG(noise)
            vutils.save_image(
                fakeImg[:, 0:3, :, :].data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_fake_nir.png'.format(cfg.CFG_NAME, epoch)))
            vutils.save_image(
                fakeImg[:, 3:6, :, :].data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_fake_vis.png'.format(cfg.CFG_NAME, epoch)))
            vutils.save_image(
                imgNir.data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_img_nir.png'.format(cfg.CFG_NAME, epoch)))
            vutils.save_image(
                imgVis.data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_img_vis.png'.format(cfg.CFG_NAME, epoch)))
            vutils.save_image(
                recNir.data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_rec_nir.png'.format(cfg.CFG_NAME, epoch)))
            vutils.save_image(
                recVis.data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_rec_vis.png'.format(cfg.CFG_NAME, epoch)))

        if (epoch + 1) % cfg.G.TRAIN.SAVE_EPOCH == 0:
            saveOptimizer(cfg, optimizer, epoch)
            saveModel(cfg, encoderVis, encoderNir, netG, epoch)
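# The training loop above relies on two helpers defined elsewhere in the
# repository: reparameterize (the standard reparameterization trick) and
# klLoss. A minimal sketch of the analytic term klLoss is assumed to compute,
# i.e. the per-sample KL( N(mu, diag(sigma^2)) || N(0, I) ) that .mean() is
# then applied to (the exact reduction in the real helper may differ):
#
#     def klLoss(mu, logvar):
#         # closed-form KL to a standard normal, summed over latent dimensions
#         return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)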
def generate_swapping_plot(flags, epoch, model, samples, alphabet):
    rec_i_in_i_out = Variable(torch.zeros([121, 3, flags.img_size, flags.img_size],
                                          dtype=torch.float32))
    rec_i_in_t_out = Variable(torch.zeros([121, 3, flags.img_size, flags.img_size],
                                          dtype=torch.float32))
    rec_t_in_i_out = Variable(torch.zeros([121, 3, flags.img_size, flags.img_size],
                                          dtype=torch.float32))
    rec_t_in_t_out = Variable(torch.zeros([121, 3, flags.img_size, flags.img_size],
                                          dtype=torch.float32))
    rec_i_in_i_out = rec_i_in_i_out.to(flags.device)
    rec_i_in_t_out = rec_i_in_t_out.to(flags.device)
    rec_t_in_i_out = rec_t_in_i_out.to(flags.device)
    rec_t_in_t_out = rec_t_in_t_out.to(flags.device)

    # ground truth: samples1 -> style (rows), samples2 -> content (cols)
    img_size = torch.Size((3, flags.img_size, flags.img_size))
    for i in range(len(samples)):
        c_text_sample = plot.text_to_pil_celeba(samples[i][1].unsqueeze(0),
                                                img_size, alphabet)
        c_img_sample = samples[i][0].squeeze()
        s_text_sample = c_text_sample.clone()
        s_img_sample = c_img_sample.clone()
        rec_i_in_i_out[i + 1, :, :, :] = c_img_sample
        rec_i_in_i_out[(i + 1) * 11, :, :, :] = s_img_sample
        rec_i_in_t_out[i + 1, :, :, :] = c_img_sample
        rec_i_in_t_out[(i + 1) * 11, :, :, :] = s_text_sample
        rec_t_in_i_out[i + 1, :, :, :] = c_text_sample
        rec_t_in_i_out[(i + 1) * 11, :, :, :] = s_img_sample
        rec_t_in_t_out[i + 1, :, :, :] = c_text_sample
        rec_t_in_t_out[(i + 1) * 11, :, :, :] = s_text_sample

    # style transfer
    for i in range(len(samples)):
        for j in range(len(samples)):
            l_style = model.inference(samples[i][0].unsqueeze(0),
                                      samples[i][1].unsqueeze(0))
            l_content = model.inference(samples[j][0].unsqueeze(0),
                                        samples[j][1].unsqueeze(0))
            l_c_img = l_content['img_celeba']
            l_c_text = l_content['text']
            l_s_img = l_style['img_celeba']
            l_s_text = l_style['text']
            s_img_emb = utils.reparameterize(mu=l_s_img[0], logvar=l_s_img[1])
            c_img_emb = utils.reparameterize(mu=l_c_img[2], logvar=l_c_img[3])
            s_text_emb = utils.reparameterize(mu=l_s_text[0], logvar=l_s_text[1])
            c_text_emb = utils.reparameterize(mu=l_c_text[2], logvar=l_c_text[3])
            style_emb = {'img_celeba': s_img_emb, 'text': s_text_emb}
            emb_c_img = {'content': c_img_emb, 'style': style_emb}
            emb_c_text = {'content': c_text_emb, 'style': style_emb}
            img_c_samples = model.generate_from_latents(emb_c_img)
            text_c_samples = model.generate_from_latents(emb_c_text)
            i_in_i_out = img_c_samples['img_celeba']
            i_in_t_out = img_c_samples['text']
            t_in_i_out = text_c_samples['img_celeba']
            t_in_t_out = text_c_samples['text']
            rec_i_in_i_out[(i + 1) * 11 + (j + 1), :, :, :] = i_in_i_out
            rec_i_in_t_out[(i + 1) * 11 + (j + 1), :, :, :] = plot.text_to_pil_celeba(
                i_in_t_out, img_size, alphabet)
            rec_t_in_i_out[(i + 1) * 11 + (j + 1), :, :, :] = t_in_i_out
            rec_t_in_t_out[(i + 1) * 11 + (j + 1), :, :, :] = plot.text_to_pil_celeba(
                t_in_t_out, img_size, alphabet)

    fp_i_in_i_out = os.path.join(
        flags.dir_swapping, 'swap_i_to_i_epoch_' + str(epoch).zfill(4) + '.png')
    fp_i_in_t_out = os.path.join(
        flags.dir_swapping, 'swap_i_to_t_epoch_' + str(epoch).zfill(4) + '.png')
    fp_t_in_i_out = os.path.join(
        flags.dir_swapping, 'swap_t_to_i_epoch_' + str(epoch).zfill(4) + '.png')
    fp_t_in_t_out = os.path.join(
        flags.dir_swapping, 'swap_t_to_t_epoch_' + str(epoch).zfill(4) + '.png')
    plot_i_i = plot.create_fig(fp_i_in_i_out, rec_i_in_i_out, 11,
                               flags.save_plot_images)
    plot_i_t = plot.create_fig(fp_i_in_t_out, rec_i_in_t_out, 11,
                               flags.save_plot_images)
    plot_t_i = plot.create_fig(fp_t_in_i_out, rec_t_in_i_out, 11,
                               flags.save_plot_images)
    plot_t_t = plot.create_fig(fp_t_in_t_out, rec_t_in_t_out, 11,
                               flags.save_plot_images)
    plots_c_img = {'img_celeba': plot_i_i, 'text': plot_i_t}
    plots_c_text = {'img_celeba': plot_t_i, 'text': plot_t_t}
    plots = {'img_celeba': plots_c_img, 'text': plots_c_text}
    return plots
def generate_conditional_fig(flags, epoch, model, samples, alphabet):
    rec_i_in_i_out = Variable(torch.zeros([110, 3, flags.img_size, flags.img_size],
                                          dtype=torch.float32))
    rec_i_in_t_out = Variable(torch.zeros([110, 3, flags.img_size, flags.img_size],
                                          dtype=torch.float32))
    rec_t_in_i_out = Variable(torch.zeros([110, 3, flags.img_size, flags.img_size],
                                          dtype=torch.float32))
    rec_t_in_t_out = Variable(torch.zeros([110, 3, flags.img_size, flags.img_size],
                                          dtype=torch.float32))
    rec_i_in_i_out = rec_i_in_i_out.to(flags.device)
    rec_i_in_t_out = rec_i_in_t_out.to(flags.device)
    rec_t_in_i_out = rec_t_in_i_out.to(flags.device)
    rec_t_in_t_out = rec_t_in_t_out.to(flags.device)

    # get style from random sampling
    zi_img = Variable(torch.randn(len(samples), flags.style_m1_dim)).to(flags.device)
    zi_text = Variable(torch.randn(len(samples), flags.style_m2_dim)).to(flags.device)

    # ground truth: samples1 -> style (rows), samples2 -> content (cols)
    img_size = torch.Size((3, flags.img_size, flags.img_size))
    for i in range(len(samples)):
        c_sample_text = plot.text_to_pil_celeba(samples[i][1].unsqueeze(0),
                                                img_size, alphabet)
        c_sample_img = samples[i][0].squeeze()
        rec_i_in_i_out[i, :, :, :] = c_sample_img
        rec_i_in_t_out[i, :, :, :] = c_sample_img
        rec_t_in_i_out[i, :, :, :] = c_sample_text
        rec_t_in_t_out[i, :, :, :] = c_sample_text

    # style transfer
    random_style = {'img_celeba': None, 'text': None}
    for i in range(len(samples)):
        for j in range(len(samples)):
            latents = model.inference(input_img=samples[j][0].unsqueeze(0),
                                      input_text=samples[j][1].unsqueeze(0))
            l_c_img = latents['img_celeba'][2:]
            l_c_text = latents['text'][2:]
            if flags.factorized_representation:
                random_style = {'img_celeba': zi_img[i].unsqueeze(0),
                                'text': zi_text[i].unsqueeze(0)}
            emb_c_img = utils.reparameterize(l_c_img[0], l_c_img[1])
            emb_c_text = utils.reparameterize(l_c_text[0], l_c_text[1])
            emb_img = {'content': emb_c_img, 'style': random_style}
            emb_text = {'content': emb_c_text, 'style': random_style}
            img_cond_gen = model.generate_from_latents(emb_img)
            text_cond_gen = model.generate_from_latents(emb_text)
            i_in_i_out = img_cond_gen['img_celeba'].squeeze(0)
            i_in_t_out = plot.text_to_pil_celeba(img_cond_gen['text'],
                                                 img_size, alphabet)
            t_in_i_out = text_cond_gen['img_celeba'].squeeze(0)
            t_in_t_out = plot.text_to_pil_celeba(text_cond_gen['text'],
                                                 img_size, alphabet)
            rec_i_in_i_out[(i + 1) * 10 + j, :, :, :] = i_in_i_out
            rec_i_in_t_out[(i + 1) * 10 + j, :, :, :] = i_in_t_out
            rec_t_in_i_out[(i + 1) * 10 + j, :, :, :] = t_in_i_out
            rec_t_in_t_out[(i + 1) * 10 + j, :, :, :] = t_in_t_out

    fp_i_in_i_out = os.path.join(
        flags.dir_cond_gen, 'cond_gen_img_img_epoch_' + str(epoch).zfill(4) + '.png')
    fp_i_in_t_out = os.path.join(
        flags.dir_cond_gen, 'cond_gen_img_text_epoch_' + str(epoch).zfill(4) + '.png')
    fp_t_in_i_out = os.path.join(
        flags.dir_cond_gen, 'cond_gen_text_img_epoch_' + str(epoch).zfill(4) + '.png')
    fp_t_in_t_out = os.path.join(
        flags.dir_cond_gen, 'cond_gen_text_text_epoch_' + str(epoch).zfill(4) + '.png')
    plot_i_i = plot.create_fig(fp_i_in_i_out, rec_i_in_i_out, 10,
                               flags.save_plot_images)
    plot_i_t = plot.create_fig(fp_i_in_t_out, rec_i_in_t_out, 10,
                               flags.save_plot_images)
    plot_t_i = plot.create_fig(fp_t_in_i_out, rec_t_in_i_out, 10,
                               flags.save_plot_images)
    plot_t_t = plot.create_fig(fp_t_in_t_out, rec_t_in_t_out, 10,
                               flags.save_plot_images)
    img_cond = {'img_celeba': plot_i_i, 'text': plot_i_t}
    text_cond = {'img_celeba': plot_t_i, 'text': plot_t_t}
    plots = {'img_celeba': img_cond, 'text': text_cond}
    return plots