Example #1
    def forward(self, input_img=None, input_text=None):
        results = {}
        latents = self.inference(input_img=input_img, input_text=input_text)
        results['latents'] = latents
        mus = latents['mus']
        logvars = latents['logvars']
        weights = latents['weights']
        # The joint divergence is only defined when both modalities are given.
        if input_img is not None and input_text is not None:
            div = self.calc_joint_divergence(mus, logvars, weights)
            for key in div:
                results[key] = div[key]

        results['group_distr'] = latents['joint']
        class_embeddings = utils.reparameterize(latents['joint'][0],
                                                latents['joint'][1])

        if self.flags.factorized_representation:
            # Entries 0 and 1 of each modality's latents are the style
            # mu and logvar; entries 2 and 3 hold the content parameters.
            if input_img is not None:
                m1_s_mu, m1_s_logvar = latents['img_celeba'][:2]
                m1_style_latent_embeddings = utils.reparameterize(
                    mu=m1_s_mu, logvar=m1_s_logvar)
            if input_text is not None:
                m2_s_mu, m2_s_logvar = latents['text'][:2]
                m2_style_latent_embeddings = utils.reparameterize(
                    mu=m2_s_mu, logvar=m2_s_logvar)
        else:
            m1_style_latent_embeddings = None
            m2_style_latent_embeddings = None

        m1_rec = None
        m2_rec = None
        if input_img is not None:
            m1_rec = self.lhood_celeba(
                *self.decoder_img(m1_style_latent_embeddings, class_embeddings))
        if input_text is not None:
            m2_rec = self.lhood_text(
                *self.decoder_text(m2_style_latent_embeddings, class_embeddings))
        results['rec'] = {'img_celeba': m1_rec, 'text': m2_rec}
        return results
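Every example here routes sampling through utils.reparameterize. The helper itself is not shown; below is a minimal sketch of the standard VAE reparameterization trick it presumably wraps (an assumption, not the repository's verbatim code):

import torch

def reparameterize(mu, logvar):
    # Sketch of the standard trick; the actual utils.reparameterize may
    # differ in signature or device handling (assumption).
    # z = mu + sigma * eps, with sigma = exp(0.5 * logvar), eps ~ N(0, I).
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

Writing the sample this way keeps it differentiable with respect to mu and logvar, which is what lets the encoders above be trained through the decoder losses.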
Example #2
    def cond_generation(self, latent_distributions):
        # At least one modality entry is expected in latent_distributions.
        if 'img_celeba' in latent_distributions:
            m1_mu, m1_logvar = latent_distributions['img_celeba']
            content_cond_m1 = utils.reparameterize(mu=m1_mu, logvar=m1_logvar)
            num_samples = m1_mu.shape[0]
        if 'text' in latent_distributions:
            m2_mu, m2_logvar = latent_distributions['text']
            content_cond_m2 = utils.reparameterize(mu=m2_mu, logvar=m2_logvar)
            num_samples = m2_mu.shape[0]

        if self.flags.factorized_representation:
            # Styles are not conditioned on; draw them from the prior.
            random_style_m1 = torch.randn(num_samples, self.flags.style_m1_dim)
            random_style_m2 = torch.randn(num_samples, self.flags.style_m2_dim)
            random_style_m1 = random_style_m1.to(self.flags.device)
            random_style_m2 = random_style_m2.to(self.flags.device)
        else:
            random_style_m1 = None
            random_style_m2 = None

        style_latents = {'img_celeba': random_style_m1, 'text': random_style_m2}
        cond_gen_samples = dict()
        if 'img_celeba' in latent_distributions:
            latents_img = {'content': content_cond_m1,
                           'style': style_latents}
            cond_gen_samples['img_celeba'] = self.generate_from_latents(latents_img)
        if 'text' in latent_distributions:
            latents_text = {'content': content_cond_m2,
                            'style': style_latents}
            cond_gen_samples['text'] = self.generate_from_latents(latents_text)
        return cond_gen_samples
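A minimal usage sketch, with hypothetical variable names: latent_distributions maps a modality name to its content [mu, logvar] pair, e.g. the entries at indices 2 and 3 of the per-modality latents returned by inference (as used in Example 9 below):

# Hypothetical usage; batch_img is a preprocessed image batch.
latents = model.inference(input_img=batch_img, input_text=None)
latent_distributions = {'img_celeba': latents['img_celeba'][2:4]}
cond_samples = model.cond_generation(latent_distributions)
text_from_img = cond_samples['img_celeba']['text']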
Example #3
    def cond_generation_2a(self, latent_distribution_pairs, num_samples=None):
        if num_samples is None:
            num_samples = self.flags.batch_size

        # Standard-normal prior expert for the joint content distribution.
        mu0 = torch.zeros(1, num_samples, self.flags.class_dim)
        logvar0 = torch.zeros(1, num_samples, self.flags.class_dim)
        mu0 = mu0.to(self.flags.device)
        logvar0 = logvar0.to(self.flags.device)
        style_latents = self.get_random_styles(num_samples)
        cond_gen_2a = dict()
        for pair in latent_distribution_pairs:
            ld_pair = latent_distribution_pairs[pair]
            mu_list = [mu0]
            logvar_list = [logvar0]
            for key in ld_pair['latents']:
                mu_list.append(ld_pair['latents'][key][0].unsqueeze(0))
                logvar_list.append(ld_pair['latents'][key][1].unsqueeze(0))
            mus = torch.cat(mu_list, dim=0)
            logvars = torch.cat(logvar_list, dim=0)
            # Copy before prepending the prior weight so the caller's list
            # is not mutated on repeated calls.
            weights_pair = [self.weights[0]] + list(ld_pair['weights'])
            weights_pair = utils.reweight_weights(torch.Tensor(weights_pair))
            mu_joint, logvar_joint = self.modality_fusion(
                mus, logvars, weights_pair)
            # Alternative fusion: mu_joint, logvar_joint = poe(mus, logvars)
            c_emb = utils.reparameterize(mu_joint, logvar_joint)
            l_2a = {
                'content': c_emb,
                'style': style_latents
            }
            cond_gen_2a[pair] = self.generate_from_latents(l_2a)
        return cond_gen_2a
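The commented-out poe alternative refers to a product-of-experts fusion. A sketch of the standard closed-form product of diagonal Gaussians, under the assumption that this is what poe computes:

import torch

def poe(mus, logvars, eps=1e-8):
    # Experts are stacked along dim 0. The product of Gaussians is again
    # Gaussian: joint precision = sum of precisions, joint mean =
    # precision-weighted mean. Sketch only; the repository's numerics
    # may differ.
    precisions = 1.0 / (torch.exp(logvars) + eps)
    joint_var = 1.0 / precisions.sum(dim=0)
    joint_mu = joint_var * (mus * precisions).sum(dim=0)
    return joint_mu, torch.log(joint_var + eps)

The standard-normal expert (mu0, logvar0) prepended above plays the role of the prior expert in such a product.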
Example #4
def get_latent_samples(flags, latents, mod_names):
    l_c = latents['content']
    l_s = latents['style']
    c_emb = utils.reparameterize(l_c[0], l_c[1])
    styles = dict()
    c = {'mu': l_c[0], 'logvar': l_c[1], 'z': c_emb}
    if flags.factorized_representation:
        for key in l_s:
            s_emb = utils.reparameterize(l_s[key][0], l_s[key][1])
            s = {'mu': l_s[key][0], 'logvar': l_s[key][1], 'z': s_emb}
            styles[key] = s
    else:
        for key in mod_names:
            styles[key] = None
    emb = {'content': c, 'style': styles}
    return emb
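A usage sketch with hypothetical names, mirroring the latent structure assembled by the generation methods above:

# Joint content [mu, logvar] plus per-modality style [mu, logvar] pairs
# (styles may be None when the representation is not factorized).
latents = {'content': [mu_joint, logvar_joint],
           'style': {'img_celeba': [s_mu_img, s_logvar_img],
                     'text': [s_mu_text, s_logvar_text]}}
emb = get_latent_samples(flags, latents, mod_names=['img_celeba', 'text'])
z_content = emb['content']['z']
z_style_img = emb['style']['img_celeba']['z']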
Example #5
    def forward(self, input_mnist=None, input_svhn=None, input_text=None):
        latents = self.inference(input_mnist, input_svhn, input_text)
        results = dict()
        results['latents'] = latents

        results['group_distr'] = latents['joint']
        class_embeddings = utils.reparameterize(latents['joint'][0],
                                                latents['joint'][1])
        div = self.calc_joint_divergence(latents['mus'], latents['logvars'],
                                         latents['weights'])
        for key in div:
            results[key] = div[key]

        results_rec = dict()
        if input_mnist is not None:
            m1_s_mu, m1_s_logvar = latents['img_mnist'][:2]
            if self.flags.factorized_representation:
                m1_s_embeddings = utils.reparameterize(mu=m1_s_mu,
                                                       logvar=m1_s_logvar)
            else:
                m1_s_embeddings = None
            m1_rec = self.lhood_mnist(
                *self.decoder_mnist(m1_s_embeddings, class_embeddings))
            results_rec['img_mnist'] = m1_rec
        if input_svhn is not None:
            m2_s_mu, m2_s_logvar = latents['img_svhn'][:2]
            if self.flags.factorized_representation:
                m2_s_embeddings = utils.reparameterize(mu=m2_s_mu,
                                                       logvar=m2_s_logvar)
            else:
                m2_s_embeddings = None
            m2_rec = self.lhood_svhn(
                *self.decoder_svhn(m2_s_embeddings, class_embeddings))
            results_rec['img_svhn'] = m2_rec
        if input_text is not None:
            m3_s_mu, m3_s_logvar = latents['text'][:2]
            if self.flags.factorized_representation:
                m3_s_embeddings = utils.reparameterize(mu=m3_s_mu,
                                                       logvar=m3_s_logvar)
            else:
                m3_s_embeddings = None
            m3_rec = self.lhood_text(
                *self.decoder_text(m3_s_embeddings, class_embeddings))
            results_rec['text'] = m3_rec
        results['rec'] = results_rec
        return results
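A usage sketch (hypothetical variable names); per the signature, any subset of the three modalities can be passed, and the reconstruction dictionary only gains entries for the modalities that were given:

# Forward pass with all three modalities present.
results = model(input_mnist=x_mnist, input_svhn=x_svhn, input_text=x_text)
rec_mnist = results['rec']['img_mnist']
joint_mu, joint_logvar = results['group_distr'][:2]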
Example #6
    def cond_generation_1a(self, latent_distributions, num_samples=None):
        if num_samples is None:
            num_samples = self.flags.batch_size

        style_latents = self.get_random_styles(num_samples)
        cond_gen_samples = dict()
        for key in latent_distributions:
            mu, logvar = latent_distributions[key]
            content_rep = utils.reparameterize(mu=mu, logvar=logvar)
            latents = {'content': content_rep, 'style': style_latents}
            cond_gen_samples[key] = self.generate_from_latents(latents)
        return cond_gen_samples
Example #7
def main():
    cfg = parseArgs()

    if not os.path.exists(cfg.MISC.OUTPUT_PATH):
        os.makedirs(cfg.MISC.OUTPUT_PATH)

    encoderVis, encoderNir, netG = defineG(hdim=cfg.G.TRAIN.HDIM)
    netIP = defineIP(isTrain=False)

    print('==> Loading pre-trained identity preserving model from {}'.format(
        cfg.G.NET_IP))
    checkpoint = torch.load(cfg.G.NET_IP)
    pretrainedDict = checkpoint['state_dict']
    modelDict = netIP.state_dict()
    pretrainedDict = {
        k: v
        for k, v in pretrainedDict.items() if k in modelDict
    }
    modelDict.update(pretrainedDict)
    netIP.load_state_dict(modelDict)

    for param in netIP.parameters():
        param.requires_grad = False

    # optimizer
    optimizer = torch.optim.Adam(list(netG.parameters()) +
                                 list(encoderVis.parameters()) +
                                 list(encoderNir.parameters()),
                                 lr=cfg.G.TRAIN.LR)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=cfg.G.TRAIN.MILESTONE, gamma=0.1, last_epoch=-1)

    # resume
    if cfg.G.TRAIN.RESUME:
        encoderVis, encoderNir, netG, startEpoch = loadModel(
            cfg, encoderNir, encoderVis, netG)
        optimizer = loadOptimizer(cfg, optimizer)
    else:
        startEpoch = 0

    # criterion
    l2Loss = torch.nn.MSELoss()
    l1Loss = torch.nn.L1Loss()
    smoothL1Loss = torch.nn.SmoothL1Loss()
    lossDict = {'l1': l1Loss, 'l2': l2Loss, 'smoothL1': smoothL1Loss}
    ipLoss = lossDict[cfg.G.TRAIN.IP_LOSS].cuda()
    pairLoss = lossDict[cfg.G.TRAIN.PAIR_LOSS].cuda()
    recLoss = lossDict[cfg.G.TRAIN.REC_LOSS].cuda()

    # dataloader
    trainLoader = torch.utils.data.DataLoader(
        GenDataset(imgRoot=cfg.G.DATASET.ROOT,
                   protocolsRoot=cfg.G.DATASET.PROTOCOLS),
        batch_size=cfg.G.TRAIN.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.G.TRAIN.NUM_WORKERS,
        pin_memory=True,
        drop_last=False)

    # writer
    writer = SummaryWriter(logdir=os.path.join(cfg.MISC.OUTPUT_PATH, 'run',
                                               '{}'.format(cfg.CFG_NAME)))

    for epoch in range(startEpoch, cfg.G.TRAIN.EPOCH):
        batchTime = AverageMeter()
        dataTime = AverageMeter()
        losses = AverageMeter()
        recLosses = AverageMeter()
        klLosses = AverageMeter()
        mmdLosses = AverageMeter()
        ipLosses = AverageMeter()
        pairLosses = AverageMeter()

        encoderVis.train()
        encoderNir.train()
        netG.train()
        netIP.eval()

        startTime = time.time()
        for i, batch in enumerate(trainLoader):
            dataTime.update(time.time() - startTime)

            imgNir = Variable(batch['0'].cuda())
            imgVis = Variable(batch['1'].cuda())

            img = torch.cat((imgNir, imgVis), dim=1)

            # encoder forward
            muNir, logvarNir = encoderNir(imgNir)
            muVis, logvarVis = encoderVis(imgVis)

            # re-parametrization
            zNir = reparameterize(muNir, logvarNir)
            zVis = reparameterize(muVis, logvarVis)

            # generator
            rec = netG(torch.cat((zNir, zVis), dim=1))

            # vae loss
            # lossRec = reconLoss(rec, img, True) / 2.
            lossRec = cfg.G.TRAIN.LAMBDA_REC * recLoss(rec, img) / 2.0
            lossKL = cfg.G.TRAIN.LAMBDA_KL * (
                klLoss(muNir, logvarNir).mean() +
                klLoss(muVis, logvarVis).mean()) / 2.0

            # "MMD" loss: in fact a first-moment matching term, the mean
            # absolute difference between the batch means of zNir and zVis.
            lossMMD = cfg.G.TRAIN.LAMBDA_MMD * torch.abs(
                zNir.mean(dim=0) - zVis.mean(dim=0)).mean()

            # identity preserving loss
            recNir = rec[:, 0:3, :, :]
            recVis = rec[:, 3:6, :, :]

            embedNir = F.normalize(netIP(rgb2gray(imgNir))[0], p=2, dim=1)
            embedVis = F.normalize(netIP(rgb2gray(imgVis))[0], p=2, dim=1)

            recEmbedNir = F.normalize(netIP(rgb2gray(recNir))[0], p=2, dim=1)
            recEmbedVis = F.normalize(netIP(rgb2gray(recVis))[0], p=2, dim=1)

            lossIP = cfg.G.TRAIN.LAMBDA_IP * (
                ipLoss(recEmbedNir, embedNir.detach()) +
                ipLoss(recEmbedVis, embedVis.detach())) / 2.0
            lossPair = cfg.G.TRAIN.LAMBDA_PAIR * pairLoss(
                recEmbedNir, recEmbedVis)

            # Warm-up: down-weight the regularizers for the first two epochs.
            if epoch < 2:
                loss = lossRec + 0.01 * lossKL + 0.01 * lossMMD + 0.01 * lossIP + 0.01 * lossPair
            else:
                loss = lossRec + lossKL + lossMMD + lossIP + lossPair
            losses.update(loss.item())
            recLosses.update(lossRec.item())
            klLosses.update(lossKL.item())
            mmdLosses.update(lossMMD.item())
            ipLosses.update(lossIP.item())
            pairLosses.update(lossPair.item())

            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batchTime.update(time.time() - startTime)
            startTime = time.time()

            if i % cfg.G.TRAIN.PRINT_FREQ == 0:
                info = '==> Epoch: [{:0>4d}][{:3d}/{:3d}] Batch time: {:4.3f} Data time: {:4.3f} | '.format(
                    epoch, i, len(trainLoader), batchTime.avg, dataTime.avg)
                info += 'Loss: rec: {:4.3f} kl: {:4.3f} mmd: {:4.3f} ip: {:4.8f} pair: {:4.8f}'.format(
                    lossRec.item(), lossKL.item(), lossMMD.item(),
                    lossIP.item(), lossPair.item())
                print(info)

        # Advance the MultiStepLR schedule once per epoch.
        scheduler.step()
        # writer
        writer.add_scalar('loss/loss', losses.avg, epoch)
        writer.add_scalar('loss/recLoss', recLosses.avg, epoch)
        writer.add_scalar('loss/klLoss', klLosses.avg, epoch)
        writer.add_scalar('loss/mmdLoss', mmdLosses.avg, epoch)
        writer.add_scalar('loss/ipLoss', ipLosses.avg, epoch)
        writer.add_scalar('loss/pairLoss', pairLosses.avg, epoch)

        x = vutils.make_grid(imgNir.data, normalize=True, scale_each=True)
        writer.add_image('nir/imgNir', x, epoch)
        x = vutils.make_grid(imgVis.data, normalize=True, scale_each=True)
        writer.add_image('vis/imgVis', x, epoch)
        x = vutils.make_grid(recNir.data, normalize=True, scale_each=True)
        writer.add_image('nir/recNIR', x, epoch)
        x = vutils.make_grid(recVis.data, normalize=True, scale_each=True)
        writer.add_image('vis/recVis', x, epoch)

        # Sample one latent code and duplicate it so the paired NIR and VIS
        # images are decoded from the same code.
        noise = torch.zeros(cfg.G.TRAIN.BATCH_SIZE,
                            cfg.G.TRAIN.HDIM).normal_(0, 1)
        noise = torch.cat((noise, noise), dim=1)
        noise = noise.cuda()
        fakeImg = netG(noise)
        x = vutils.make_grid(fakeImg[:, 0:3, :, :].data,
                             normalize=True,
                             scale_each=True)
        writer.add_image('fake/fakeNir', x, epoch)
        x = vutils.make_grid(fakeImg[:, 3:6, :, :].data,
                             normalize=True,
                             scale_each=True)
        writer.add_image('fake/fakeVis', x, epoch)

        # evaluation
        if (epoch + 1) % cfg.G.TEST.FREQ == 0:
            if not os.path.isdir(cfg.G.TEST.IMG_DUMP):
                os.makedirs(cfg.G.TEST.IMG_DUMP)
            noise = torch.zeros(cfg.G.TRAIN.BATCH_SIZE,
                                cfg.G.TRAIN.HDIM).normal_(0, 1)
            noise = torch.cat((noise, noise), dim=1)
            noise = noise.cuda()

            fakeImg = netG(noise)

            vutils.save_image(
                fakeImg[:, 0:3, :, :].data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_fake_nir.png'.format(cfg.CFG_NAME,
                                                          epoch)))
            vutils.save_image(
                fakeImg[:, 3:6, :, :].data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_fake_vis.png'.format(cfg.CFG_NAME,
                                                          epoch)))
            vutils.save_image(
                imgNir.data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_img_nir.png'.format(cfg.CFG_NAME, epoch)))
            vutils.save_image(
                imgVis.data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_img_vis.png'.format(cfg.CFG_NAME, epoch)))
            vutils.save_image(
                recNir.data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_rec_nir.png'.format(cfg.CFG_NAME, epoch)))
            vutils.save_image(
                recVis.data,
                os.path.join(
                    cfg.G.TEST.IMG_DUMP,
                    '{}_epoch_{:03d}_rec_vis.png'.format(cfg.CFG_NAME, epoch)))

        if (epoch + 1) % cfg.G.TRAIN.SAVE_EPOCH == 0:
            saveOptimizer(cfg, optimizer, epoch)
            saveModel(cfg, encoderVis, encoderNir, netG, epoch)
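The training loop depends on a klLoss helper that is not shown. A sketch of the standard closed-form KL divergence to a unit Gaussian that it presumably computes (an assumption; the real helper may reduce differently):

import torch

def klLoss(mu, logvar):
    # KL(N(mu, diag(sigma^2)) || N(0, I)) per sample, summed over the
    # latent dimensions: 0.5 * sum(mu^2 + exp(logvar) - logvar - 1).
    # The loop above applies .mean() over the batch afterwards.
    return 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1.0, dim=1)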
Example #8
def generate_swapping_plot(flags, epoch, model, samples, alphabet):
    rec_i_in_i_out = Variable(
        torch.zeros([121, 3, flags.img_size, flags.img_size],
                    dtype=torch.float32))
    rec_i_in_t_out = Variable(
        torch.zeros([121, 3, flags.img_size, flags.img_size],
                    dtype=torch.float32))
    rec_t_in_i_out = Variable(
        torch.zeros([121, 3, flags.img_size, flags.img_size],
                    dtype=torch.float32))
    rec_t_in_t_out = Variable(
        torch.zeros([121, 3, flags.img_size, flags.img_size],
                    dtype=torch.float32))
    rec_i_in_i_out = rec_i_in_i_out.to(flags.device)
    rec_i_in_t_out = rec_i_in_t_out.to(flags.device)
    rec_t_in_i_out = rec_t_in_i_out.to(flags.device)
    rec_t_in_t_out = rec_t_in_t_out.to(flags.device)

    # ground truth: samples1 -> style (rows), samples2 -> content (cols)
    img_size = torch.Size((3, flags.img_size, flags.img_size))
    for i in range(len(samples)):
        c_text_sample = plot.text_to_pil_celeba(samples[i][1].unsqueeze(0),
                                                img_size, alphabet)
        c_img_sample = samples[i][0].squeeze()
        s_text_sample = c_text_sample.clone()
        s_img_sample = c_img_sample.clone()
        rec_i_in_i_out[i + 1, :, :, :] = c_img_sample
        rec_i_in_i_out[(i + 1) * 11, :, :, :] = s_img_sample
        rec_i_in_t_out[i + 1, :, :, :] = c_img_sample
        rec_i_in_t_out[(i + 1) * 11, :, :, :] = s_text_sample
        rec_t_in_i_out[i + 1, :, :, :] = c_text_sample
        rec_t_in_i_out[(i + 1) * 11, :, :, :] = s_img_sample
        rec_t_in_t_out[i + 1, :, :, :] = c_text_sample
        rec_t_in_t_out[(i + 1) * 11, :, :, :] = s_text_sample

    # style transfer
    for i in range(len(samples)):
        for j in range(len(samples)):
            l_style = model.inference(samples[i][0].unsqueeze(0),
                                      samples[i][1].unsqueeze(0))
            l_content = model.inference(samples[j][0].unsqueeze(0),
                                        samples[j][1].unsqueeze(0))

            l_c_img = l_content['img_celeba']
            l_c_text = l_content['text']
            l_s_img = l_style['img_celeba']
            l_s_text = l_style['text']
            s_img_emb = utils.reparameterize(mu=l_s_img[0], logvar=l_s_img[1])
            c_img_emb = utils.reparameterize(mu=l_c_img[2], logvar=l_c_img[3])
            s_text_emb = utils.reparameterize(mu=l_s_text[0],
                                              logvar=l_s_text[1])
            c_text_emb = utils.reparameterize(mu=l_c_text[2],
                                              logvar=l_c_text[3])
            style_emb = {
                'img_celeba': s_img_emb,
                'text': s_text_emb
            }
            emb_c_img = {
                'content': c_img_emb,
                'style': style_emb
            }
            emb_c_text = {
                'content': c_text_emb,
                'style': style_emb
            }

            img_c_samples = model.generate_from_latents(emb_c_img)
            text_c_samples = model.generate_from_latents(emb_c_text)
            i_in_i_out = img_c_samples['img_celeba']
            i_in_t_out = img_c_samples['text']
            t_in_i_out = text_c_samples['img_celeba']
            t_in_t_out = text_c_samples['text']
            idx = (i + 1) * 11 + (j + 1)
            rec_i_in_i_out[idx, :, :, :] = i_in_i_out
            rec_i_in_t_out[idx, :, :, :] = plot.text_to_pil_celeba(
                i_in_t_out, img_size, alphabet)
            rec_t_in_i_out[idx, :, :, :] = t_in_i_out
            rec_t_in_t_out[idx, :, :, :] = plot.text_to_pil_celeba(
                t_in_t_out, img_size, alphabet)
    fp_i_in_i_out = os.path.join(
        flags.dir_swapping,
        'swap_i_to_i_epoch_' + str(epoch).zfill(4) + '.png')
    fp_i_in_t_out = os.path.join(
        flags.dir_swapping,
        'swap_i_to_t_epoch_' + str(epoch).zfill(4) + '.png')
    fp_t_in_i_out = os.path.join(
        flags.dir_swapping,
        'swap_t_to_i_epoch_' + str(epoch).zfill(4) + '.png')
    fp_t_in_t_out = os.path.join(
        flags.dir_swapping,
        'swap_t_to_t_epoch_' + str(epoch).zfill(4) + '.png')
    plot_i_i = plot.create_fig(fp_i_in_i_out, rec_i_in_i_out, 11,
                               flags.save_plot_images)
    plot_i_t = plot.create_fig(fp_i_in_t_out, rec_i_in_t_out, 11,
                               flags.save_plot_images)
    plot_t_i = plot.create_fig(fp_t_in_i_out, rec_t_in_i_out, 11,
                               flags.save_plot_images)
    plot_t_t = plot.create_fig(fp_t_in_t_out, rec_t_in_t_out, 11,
                               flags.save_plot_images)
    plots_c_img = {
        'img_celeba': plot_i_i,
        'text': plot_i_t
    }
    plots_c_text = {
        'img_celeba': plot_t_i,
        'text': plot_t_t
    }
    plots = {
        'img_celeba': plots_c_img,
        'text': plots_c_text
    }
    return plots
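For orientation: each 121-slot buffer is an 11 x 11 grid. The first loop fills row 0 (indices 1..10) with the content ground truth and column 0 (indices 11, 22, ..., 110) with the style ground truth; the swap generated from style i and content j lands at index (i + 1) * 11 + (j + 1), e.g. index 12 for i = j = 0. Index 0, the corner cell, is never written and stays blank.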
Example #9
def generate_conditional_fig(flags, epoch, model, samples, alphabet):
    rec_i_in_i_out = Variable(
        torch.zeros([110, 3, flags.img_size, flags.img_size],
                    dtype=torch.float32))
    rec_i_in_t_out = Variable(
        torch.zeros([110, 3, flags.img_size, flags.img_size],
                    dtype=torch.float32))
    rec_t_in_i_out = Variable(
        torch.zeros([110, 3, flags.img_size, flags.img_size],
                    dtype=torch.float32))
    rec_t_in_t_out = Variable(
        torch.zeros([110, 3, flags.img_size, flags.img_size],
                    dtype=torch.float32))
    rec_i_in_i_out = rec_i_in_i_out.to(flags.device)
    rec_i_in_t_out = rec_i_in_t_out.to(flags.device)
    rec_t_in_i_out = rec_t_in_i_out.to(flags.device)
    rec_t_in_t_out = rec_t_in_t_out.to(flags.device)
    # get style from random sampling
    zi_img = Variable(torch.randn(len(samples),
                                  flags.style_m1_dim)).to(flags.device)
    zi_text = Variable(torch.randn(len(samples),
                                   flags.style_m2_dim)).to(flags.device)

    # ground truth: samples1 -> style (rows), samples2 -> content (cols)
    img_size = torch.Size((3, flags.img_size, flags.img_size))
    for i in range(len(samples)):
        c_sample_text = plot.text_to_pil_celeba(samples[i][1].unsqueeze(0),
                                                img_size, alphabet)
        c_sample_img = samples[i][0].squeeze()
        rec_i_in_i_out[i, :, :, :] = c_sample_img
        rec_i_in_t_out[i, :, :, :] = c_sample_img
        rec_t_in_i_out[i, :, :, :] = c_sample_text
        rec_t_in_t_out[i, :, :, :] = c_sample_text

    # style transfer
    random_style = {
        'img_celeba': None,
        'text': None
    }
    for i in range(len(samples)):
        for j in range(len(samples)):
            latents = model.inference(input_img=samples[j][0].unsqueeze(0),
                                      input_text=samples[j][1].unsqueeze(0))
            l_c_img = latents['img_celeba'][2:]
            l_c_text = latents['text'][2:]
            if flags.factorized_representation:
                random_style = {
                    'img_celeba': zi_img[i].unsqueeze(0),
                    'text': zi_text[i].unsqueeze(0)
                }
            emb_c_img = utils.reparameterize(l_c_img[0], l_c_img[1])
            emb_c_text = utils.reparameterize(l_c_text[0], l_c_text[1])
            emb_img = {
                'content': emb_c_img,
                'style': random_style
            }
            emb_text = {
                'content': emb_c_text,
                'style': random_style
            }
            img_cond_gen = model.generate_from_latents(emb_img)
            text_cond_gen = model.generate_from_latents(emb_text)
            i_in_i_out = img_cond_gen['img_celeba'].squeeze(0)
            i_in_t_out = plot.text_to_pil_celeba(img_cond_gen['text'],
                                                 img_size, alphabet)
            t_in_i_out = text_cond_gen['img_celeba'].squeeze(0)
            t_in_t_out = plot.text_to_pil_celeba(text_cond_gen['text'],
                                                 img_size, alphabet)
            rec_i_in_i_out[(i + 1) * 10 + j, :, :, :] = i_in_i_out
            rec_i_in_t_out[(i + 1) * 10 + j, :, :, :] = i_in_t_out
            rec_t_in_i_out[(i + 1) * 10 + j, :, :, :] = t_in_i_out
            rec_t_in_t_out[(i + 1) * 10 + j, :, :, :] = t_in_t_out

    fp_i_in_i_out = os.path.join(
        flags.dir_cond_gen,
        'cond_gen_img_img_epoch_' + str(epoch).zfill(4) + '.png')
    fp_i_in_t_out = os.path.join(
        flags.dir_cond_gen,
        'cond_gen_img_text_epoch_' + str(epoch).zfill(4) + '.png')
    fp_t_in_i_out = os.path.join(
        flags.dir_cond_gen,
        'cond_gen_text_img_epoch_' + str(epoch).zfill(4) + '.png')
    fp_t_in_t_out = os.path.join(
        flags.dir_cond_gen,
        'cond_gen_text_text_epoch_' + str(epoch).zfill(4) + '.png')
    plot_i_i = plot.create_fig(fp_i_in_i_out, rec_i_in_i_out, 10,
                               flags.save_plot_images)
    plot_i_t = plot.create_fig(fp_i_in_t_out, rec_i_in_t_out, 10,
                               flags.save_plot_images)
    plot_t_i = plot.create_fig(fp_t_in_i_out, rec_t_in_i_out, 10,
                               flags.save_plot_images)
    plot_t_t = plot.create_fig(fp_t_in_t_out, rec_t_in_t_out, 10,
                               flags.save_plot_images)
    img_cond = {
        'img_celeba': plot_i_i,
        'text': plot_i_t
    }
    text_cond = {
        'img_celeba': plot_t_i,
        'text': plot_t_t
    }
    plots = {
        'img_celeba': img_cond,
        'text': text_cond
    }
    return plots
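The layout here differs from Example #8: each 110-slot buffer is an 11 x 10 grid with no ground-truth style column, since styles are drawn from the prior rather than inferred. Row 0 (indices 0..9) holds the ground-truth inputs, and the sample conditioned on content j with random style i lands at index (i + 1) * 10 + j.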