Beispiel #1
0
    def sampling(self, split_dir):
        """Generate images from human and imperfect captions and print the mean DDVA.

        Loads the text encoder from ``cfg.TRAIN.NET_E`` and the generator from
        ``cfg.TRAIN.NET_G``, then iterates once over ``self.data_loader``. For
        each batch it generates one image set from the human captions and one
        from the imperfect captions, saves the last entry of the human-caption
        image list for every sample as ``<save_dir>/single_s<idx>.png``, and
        accumulates the negated value returned by ``negative_ddva``. The
        average over all batches is printed at the end.

        If ``cfg.TRAIN.NET_G`` is empty, or the data loader yields no batches,
        an error message is printed and nothing else happens.

        Args:
            split_dir: Sub-directory name for the generated images; ``'test'``
                is mapped to ``'valid'``. The directory is created under the
                generator checkpoint path with everything from ``'.pth'`` on
                stripped.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'

        netG = G_DCGAN() if cfg.GAN.B_DCGAN else G_NET()
        netG.apply(weights_init)
        netG.cuda()
        netG.eval()

        text_encoder = RNN_ENCODER(self.n_words,
                                   nhidden=cfg.TEXT.EMBEDDING_DIM)
        # map_location keeps the loaded tensors on CPU; the modules are moved
        # to the GPU explicitly afterwards.
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder = text_encoder.cuda()
        text_encoder.eval()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        # `volatile=True` is ignored by PyTorch >= 0.4; gradient tracking is
        # disabled with torch.no_grad() around the sampling loop instead.
        noise = Variable(torch.FloatTensor(batch_size, nz)).cuda()

        model_dir = cfg.TRAIN.NET_G
        state_dict = torch.load(model_dir,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load G from: ', model_dir)

        # Directory that receives the generated images.
        save_dir = '%s/%s' % (model_dir[:model_dir.rfind('.pth')], split_dir)
        mkdir_p(save_dir)
        image_prefix = '%s/single' % (save_dir)

        def generate(caps, lens):
            """Encode one caption batch and run the generator on fresh noise."""
            hidden = text_encoder.init_hidden(batch_size)
            words_embs, sent_emb = text_encoder(caps, lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            # Token id 0 is treated as padding; the mask is trimmed to the
            # number of word embeddings returned by the encoder.
            mask = (caps == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]
            noise.data.normal_(0, 1)
            generated, _, _, _ = netG(noise, sent_emb, words_embs, mask)
            return generated

        idx = 0  # running index used in the saved file names
        avg_ddva = 0
        num_batches = 0

        with torch.no_grad():
            for step, data in enumerate(self.data_loader, 0):
                if step % 100 == 0:
                    print('step: ', step)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Images for the human captions.
                _, captions, cap_lens, _, keys, _, _, _ = prepare_data(
                    [captions, cap_lens, misc])
                fake_imgs = generate(captions, cap_lens)

                # Images for the imperfect captions.
                imgs, imperfect_captions, imperfect_cap_lens, _, imperfect_keys, _, _, _ = \
                    prepare_data([imperfect_captions, imperfect_cap_lens, misc])
                imperfect_fake_imgs = generate(imperfect_captions,
                                               imperfect_cap_lens)

                # Sort both result sets by key so that they line up
                # sample-by-sample.
                keys, captions, cap_lens, fake_imgs, _, _ = sort_by_keys(
                    keys, captions, cap_lens, fake_imgs, None, None)
                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, true_imgs, _ = \
                    sort_by_keys(imperfect_keys, imperfect_captions,
                                 imperfect_cap_lens, imperfect_fake_imgs,
                                 imgs, None)

                # Move everything the DDVA computation consumes to the
                # secondary device. The key-sorted real images are used so
                # that all three inputs share the same sample order.
                # NOTE(review): assumes sort_by_keys returns re-ordered lists
                # (as the training code relies on) - confirm.
                for i in range(len(true_imgs)):
                    true_imgs[i] = true_imgs[i].to(secondary_device)
                    imperfect_fake_imgs[i] = imperfect_fake_imgs[i].to(
                        secondary_device)
                    fake_imgs[i] = fake_imgs[i].to(secondary_device)

                # Save the last entry of the human-caption image list,
                # rescaling from [-1, 1] to [0, 255] and moving channels last.
                for j in range(batch_size):
                    im = fake_imgs[-1][j].data.cpu().numpy()
                    im = ((im + 1.0) * 127.5).astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    Image.fromarray(im).save('%s_s%d.png' % (image_prefix, idx))
                    idx += 1

                neg_ddva = negative_ddva(
                    imperfect_fake_imgs,
                    true_imgs,
                    fake_imgs,
                    reduce='mean',
                    final_only=True).data.cpu().numpy()
                avg_ddva += neg_ddva * (-1)
                num_batches += 1

                print(step)

        if num_batches == 0:
            # Avoids the NameError / division by zero of averaging nothing.
            print('Error: the data loader produced no batches!')
            return

        avg_ddva = avg_ddva / num_batches
        print('\n\nAvg_DDVA: ', avg_ddva)
Beispiel #2
0
    def train(self):
        """Train the generator and discriminators with an added DDVA loss.

        For every batch the method:

        1. encodes the human captions and generates ``fake_imgs`` with
           ``netG``;
        2. encodes the imperfect captions and generates
           ``imperfect_fake_imgs`` with ``target_netG`` using the same noise;
        3. sorts both result sets by key so that they are aligned;
        4. updates every discriminator in ``netsD``;
        5. updates ``netG`` with the generator loss, the KL loss and ten times
           the value returned by ``negative_ddva``;
        6. updates the running parameter average ``avg_param_G`` and copies
           the current ``netG`` parameters into ``target_netG``.

        A snapshot is saved through ``self.save_model`` whenever the epoch
        index is a multiple of ``cfg.TRAIN.SNAPSHOT_INTERVAL``.
        """
        # Local import keeps this block self-contained.
        from collections import deque

        text_encoder, image_encoder, netG, target_netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is not used below; it is still drawn so that the random
        # number stream is unchanged.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # Last 100 scaled negative-DDVA values, stored as plain floats so the
        # window does not keep autograd graphs alive.
        sliding_window = deque(maxlen=100)

        for epoch in range(start_epoch, self.max_epoch):
            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # Built-in next() works on every iterator; the `.next()`
                # method is a Python 2 leftover.
                data = next(data_iter)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Human captions ---------------------------------------
                data_human = [captions, cap_lens, misc]
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                #######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask, cap_lens)

                # Imperfect captions -----------------------------------
                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]
                imgs, imperfect_captions, imperfect_cap_lens, i_class_ids, imperfect_keys = \
                    prepare_data(data_imperfect)

                i_hidden = text_encoder.init_hidden(batch_size)
                i_words_embs, i_sent_emb = text_encoder(
                    imperfect_captions, imperfect_cap_lens, i_hidden)
                i_words_embs, i_sent_emb = i_words_embs.detach(), i_sent_emb.detach()
                i_mask = (imperfect_captions == 0)
                i_num_words = i_words_embs.size(2)
                if i_mask.size(1) > i_num_words:
                    i_mask = i_mask[:, :i_num_words]

                # The noise is deliberately NOT re-sampled here: the target
                # generator sees the same noise as netG.
                imperfect_fake_imgs, _, _, _ = target_netG(
                    noise, i_sent_emb, i_words_embs, i_mask)

                # Sort the results by keys to align --------------------
                bag = [sent_emb, real_labels, fake_labels, words_embs, class_ids]
                keys, captions, cap_lens, fake_imgs, _, sorted_bag = sort_by_keys(
                    keys, captions, cap_lens, fake_imgs, None, bag)
                sent_emb, real_labels, fake_labels, words_embs, class_ids = sorted_bag

                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, _ = \
                    sort_by_keys(imperfect_keys, imperfect_captions, imperfect_cap_lens,
                                 imperfect_fake_imgs, imgs, None)

                #######################################################
                # (3) Update D network
                #######################################################
                D_logs = ''
                for i, netD in enumerate(netsD):
                    netD.zero_grad()
                    errD, log = discriminator_loss(netD, imgs[i], fake_imgs[i],
                                                   sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    D_logs += log

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                #######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()

                # This value is the generator loss (generator_loss + KL), so
                # it is labelled as such.
                print('Generator loss: ', errG_total)

                # Compute and add ddva loss ----------------------------
                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)
                neg_ddva *= 10.  # Scale so that the ddva score is not overwhelmed by other losses.
                errG_total += neg_ddva.to(cfg.GPU_ID)

                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # Exponential moving average of the generator parameters.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # deque(maxlen=100) drops the oldest entry by itself.
                sliding_window.append(neg_ddva.item())
                sliding_avg_ddva = sum(sliding_window) / len(sliding_window)

                print('sliding_window avg NEG DDVA: ', sliding_avg_ddva)
                print('Negative ddva: ', neg_ddva)

                # Periodic progress line. In the original text only part of
                # this statement was commented out, leaving a dangling
                # fragment that does not parse.
                if gen_iterations % 100 == 0:
                    print('Epoch [{}/{}] Step [{}/{}]'.format(epoch, self.max_epoch, step,
                                                              self.num_batches) + ' ' + D_logs + ' ' + G_logs)

                # Copy parameters to the target network after every step ...
                load_params(target_netG, copy_G_params(netG))
                # ... and keep the target network frozen.
                for p in target_netG.parameters():
                    p.requires_grad = False

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)
Beispiel #3
0
    def train(self):
        """Train generator and discriminators with wrong-caption and DDVA losses.

        For every batch the method encodes the human captions, the "wrong"
        captions returned by ``prepare_data`` and the imperfect captions;
        generates ``fake_imgs`` with ``netG`` and ``imperfect_fake_imgs`` with
        ``target_netG`` (on ``secondary_device``, using the same noise values);
        aligns both result sets with ``sort_by_keys``; updates every
        discriminator; and finally updates ``netG`` with the generator loss,
        the KL loss and ten times the value returned by ``negative_ddva``.

        ``target_netG`` receives a copy of the ``netG`` parameters every 20
        generator iterations. A snapshot is saved whenever the epoch index is
        a multiple of ``cfg.TRAIN.SNAPSHOT_INTERVAL`` and once more after the
        last epoch.
        """
        text_encoder, image_encoder, netG, target_netG, netsD, start_epoch, style_loss = \
            self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is not used below; it is still drawn so that the random
        # number stream is unchanged.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # Built-in next() works on every iterator; the `.next()`
                # method is a Python 2 leftover.
                data = next(data_iter)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for human-text ------------------------
                data_human = [captions, cap_lens, misc]

                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # wrong word and sentence embeddings
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(), w_sent_emb.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                # Generate images for imperfect caption-text ------------
                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]

                imgs, imperfect_captions, imperfect_cap_lens, i_class_ids, imperfect_keys, i_wrong_caps, \
                    i_wrong_caps_len, i_wrong_cls_id = prepare_data(data_imperfect)

                i_hidden = text_encoder.init_hidden(batch_size)
                i_words_embs, i_sent_emb = text_encoder(
                    imperfect_captions, imperfect_cap_lens, i_hidden)
                i_words_embs, i_sent_emb = i_words_embs.detach(), i_sent_emb.detach()
                i_mask = (imperfect_captions == 0)
                i_num_words = i_words_embs.size(2)

                if i_mask.size(1) > i_num_words:
                    i_mask = i_mask[:, :i_num_words]

                # Move the target generator's inputs to the secondary device.
                # IMPORTANT: the same noise values are reused, but through a
                # separate name. Rebinding `noise` itself would hand netG a
                # tensor on the secondary device from the next iteration on.
                target_noise = noise.to(secondary_device)
                i_sent_emb = i_sent_emb.to(secondary_device)
                i_words_embs = i_words_embs.to(secondary_device)
                i_mask = i_mask.to(secondary_device)

                # Generate images.
                imperfect_fake_imgs, _, _, _ = target_netG(
                    target_noise, i_sent_emb, i_words_embs, i_mask)

                # Sort the results by keys to align ---------------------
                bag = [
                    sent_emb, real_labels, fake_labels, words_embs, class_ids,
                    w_words_embs, wrong_caps_len, wrong_cls_id
                ]

                keys, captions, cap_lens, fake_imgs, _, sorted_bag = sort_by_keys(
                    keys, captions, cap_lens, fake_imgs, None, bag)

                sent_emb, real_labels, fake_labels, words_embs, class_ids, \
                    w_words_embs, wrong_caps_len, wrong_cls_id = sorted_bag

                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, _ = \
                    sort_by_keys(imperfect_keys, imperfect_captions, imperfect_cap_lens,
                                 imperfect_fake_imgs, imgs, None)

                # Update the discriminators -----------------------------
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels, words_embs,
                                              cap_lens, image_encoder,
                                              class_ids, w_words_embs,
                                              wrong_caps_len, wrong_cls_id)
                    # backward and update parameters
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD)

                step += 1
                gen_iterations += 1

                # Update the generator ----------------------------------
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, style_loss, imgs)
                kl_loss = KL_loss(mu, logvar)

                errG_total += kl_loss

                G_logs += 'kl_loss: %.2f ' % kl_loss

                # Shift the real and generated images to the secondary device,
                # where imperfect_fake_imgs was produced.
                for i in range(len(imgs)):
                    imgs[i] = imgs[i].to(secondary_device)
                    fake_imgs[i] = fake_imgs[i].to(secondary_device)

                # Compute and add ddva loss -----------------------------
                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)
                neg_ddva *= 10.  # Scale so that the ddva score is not overwhelmed by other losses.
                errG_total += neg_ddva.to(cfg.GPU_ID)
                G_logs += 'negative_ddva_loss: %.2f ' % neg_ddva

                errG_total.backward()

                optimizerG.step()
                # Exponential moving average of the generator parameters.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)

                # Copy parameters to the target network.
                if gen_iterations % 20 == 0:
                    load_params(target_netG, copy_G_params(netG))

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f neg_ddva: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total, neg_ddva, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Beispiel #4
0
    def train(self):
        """Run the single-generator training step with wrong-caption and DDVA losses.

        Per processed batch the method encodes the human captions, the "wrong"
        captions returned by ``prepare_data`` and the imperfect captions;
        generates ``fake_imgs`` and ``imperfect_fake_imgs`` with the same
        ``netG`` (the noise is re-sampled between the two calls); aligns both
        result sets with ``sort_by_keys``; updates every discriminator; and
        updates ``netG`` with the generator loss, the KL loss and the value
        returned by ``negative_ddva``.

        As written, the batch loop ends with an unconditional ``break``, so
        exactly one batch is processed per epoch (none when
        ``self.num_batches`` is 0, in which case the end-of-epoch print fails
        because the loss variables are never assigned). No model snapshot is
        saved by this method.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is not used below; it is still drawn so that the random
        # number stream is unchanged.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # Built-in next() works on every iterator; the `.next()`
                # method is a Python 2 leftover.
                data = next(data_iter)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for human-text ------------------------
                data_human = [captions, cap_lens, misc]

                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # wrong word and sentence embeddings
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(), w_sent_emb.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                # Generate images for imperfect caption-text ------------
                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]

                imgs, imperfect_captions, imperfect_cap_lens, i_class_ids, imperfect_keys, i_wrong_caps, \
                    i_wrong_caps_len, i_wrong_cls_id = prepare_data(data_imperfect)

                i_hidden = text_encoder.init_hidden(batch_size)
                i_words_embs, i_sent_emb = text_encoder(
                    imperfect_captions, imperfect_cap_lens, i_hidden)
                i_words_embs, i_sent_emb = i_words_embs.detach(), i_sent_emb.detach()
                i_mask = (imperfect_captions == 0)
                i_num_words = i_words_embs.size(2)

                if i_mask.size(1) > i_num_words:
                    i_mask = i_mask[:, :i_num_words]

                # Fresh noise for the imperfect-caption images.
                noise.data.normal_(0, 1)
                imperfect_fake_imgs, _, _, _ = netG(noise, i_sent_emb,
                                                    i_words_embs, i_mask)

                # Sort the results by keys to align ---------------------
                bag = [
                    sent_emb, real_labels, fake_labels, words_embs, class_ids,
                    w_words_embs, wrong_caps_len, wrong_cls_id
                ]

                keys, captions, cap_lens, fake_imgs, _, sorted_bag = sort_by_keys(
                    keys, captions, cap_lens, fake_imgs, None, bag)

                sent_emb, real_labels, fake_labels, words_embs, class_ids, \
                    w_words_embs, wrong_caps_len, wrong_cls_id = sorted_bag

                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, _ = \
                    sort_by_keys(imperfect_keys, imperfect_captions, imperfect_cap_lens,
                                 imperfect_fake_imgs, imgs, None)

                # Update the discriminators -----------------------------
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels, words_embs,
                                              cap_lens, image_encoder,
                                              class_ids, w_words_embs,
                                              wrong_caps_len, wrong_cls_id)
                    # backward and update parameters
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD)

                step += 1
                gen_iterations += 1

                # Update the generator ----------------------------------
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, imgs)
                kl_loss = KL_loss(mu, logvar)

                errG_total += kl_loss

                G_logs += 'kl_loss: %.2f ' % kl_loss

                # Compute and add ddva loss -----------------------------
                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)

                errG_total += neg_ddva

                G_logs += 'negative_ddva_loss: %.2f ' % neg_ddva

                errG_total.backward()

                optimizerG.step()
                # Exponential moving average of the generator parameters.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)

                # NOTE(review): this unconditional break stops after the first
                # batch of every epoch, although the summary below reports
                # self.num_batches. It looks like a smoke-test/debug setting;
                # confirm before using this method for a real training run.
                break

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total, end_t - start_t))