Example #1
0
    def _generator_train_step(self, sent_emb, words_emb, cap_lens, step,
                              epoch):
        """Run a single generator update and return the scalar loss.

        Samples a noise batch, generates fake images conditioned on
        ``sent_emb``, scores them with the discriminator, adds the KL
        regulariser from the conditioning augmentation, and steps ``optG``.

        Note: the previous ``if False:`` branch (DAMSM sent/word matching
        losses plus extra metric logging) was dead code and has been
        removed.  ``words_emb``, ``cap_lens``, ``step`` and ``epoch`` are
        kept so the public signature stays backward-compatible.

        Returns:
            float: generator loss value for this step.
        """
        noise = torch.randn(self.batch_size, self.nz).to(self.device)
        fake_data, mu, logvar = self.netG(noise, sent_emb)
        fake_data = fake_data.to(self.device)

        g_out = self.netD(fake_data, sent_emb)
        # KL term regularises the conditioning-augmentation distribution.
        kl_loss = KL_loss(mu, logvar)
        loss = self.gan_loss(g_out, "gen") + kl_loss

        self.optG.zero_grad()
        loss.backward()
        self.optG.step()
        return loss.item()
    def train(self):
        """Main GAN training loop.

        Alternately updates the discriminators and the generator for
        ``self.max_epoch`` epochs, maintains an exponential moving average
        (EMA) of the generator weights for image snapshots, and saves model
        checkpoints every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Py3 fix: iterators have no .next() method; use next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot = prepare_data(data)
                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)  # padding positions are masked out
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot)
                fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    # Only the first (object-level) discriminator consumes
                    # the local labels / transformation matrices.
                    if i == 0:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels, self.gpus,
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv)
                    else:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels, self.gpus)

                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, self.gpus,
                                   local_labels=label_one_hot, transf_matrices=transf_matrices,
                                   transf_matrices_inv=transf_matrices_inv)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights; the deprecated positional
                # add_(alpha, tensor) form was replaced with the keyword form.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # Periodically visualise results with the EMA weights,
                # then restore the live weights.
                if gen_iterations % 1000 == 0:
                    print(D_logs + '\n' + G_logs)

                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch,  transf_matrices_inv,
                                          label_one_hot, name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)

        self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)
Example #3
0
    def eval(self):
        """Run one pass over the data loader computing (but not applying)
        the discriminator and generator losses, printing them per batch.

        Fixes over the original: three lines were tab-indented (a TabError
        under Python 3, and ``D_logs`` accidentally fell outside the D
        loop), the final summary referenced an undefined ``epoch``
        (NameError), and ``data_iter.next()`` is Python-2-only.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        start_t = time.time()

        data_iter = iter(self.data_loader)
        step = 0
        while step < self.num_batches:
            ######################################################
            # (1) Prepare training data and Compute text embeddings
            ######################################################
            data = next(data_iter)  # Py3: iterators have no .next()
            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

            #######################################################
            # (3) Compute D losses (no optimizer step in eval)
            ######################################################
            errD_total = 0
            D_logs = ''
            for i in range(len(netsD)):
                netsD[i].zero_grad()
                errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                          sent_emb, real_labels, fake_labels)
                errD_total += errD
                # indentation fixed: this belongs inside the loop
                D_logs += 'errD%d: %.2f ' % (i, errD.item())

            #######################################################
            # (4) Compute G losses
            ######################################################
            step += 1
            gen_iterations += 1

            netG.zero_grad()
            errG_total, G_logs = \
                generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                               words_embs, sent_emb, match_labels, cap_lens, class_ids)
            kl_loss = KL_loss(mu, logvar)
            errG_total += kl_loss
            G_logs += 'kl_loss: %.2f ' % kl_loss.item()
            print(D_logs + '\n' + G_logs)

        end_t = time.time()

        # `epoch` was undefined here; report the starting epoch instead.
        print('''[%d/%d][%d]
              Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
              % (start_epoch, self.max_epoch, self.num_batches,
                 errD_total.item(), errG_total.item(),
                 end_t - start_t))
Example #4
0
    def train(self):
        """GAN training loop with per-loss metric logging.

        Alternately updates the discriminators and generator each batch,
        maintains an EMA copy of the generator weights for snapshot images,
        forwards a ``loss_dict`` of scalar metrics to ``self.logger``, and
        checkpoints every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        loss_dict = {}
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Py3 fix: iterators have no .next() method; use next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)  # padding positions are masked out
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask, cap_lens)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD, log, d_dict = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                           sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    D_logs += log
                    loss_dict['Real_Acc_{}'.format(i)] = d_dict['Real_Acc']
                    loss_dict['Fake_Acc_{}'.format(i)] = d_dict['Fake_Acc']
                    loss_dict['errD_{}'.format(i)] = errD.item()

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs, g_dict = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                loss_dict.update(g_dict)
                loss_dict['kl_loss'] = kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights; deprecated positional
                # add_(alpha, tensor) replaced with the keyword form.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print('Epoch [{}/{}] Step [{}/{}]'.format(epoch, self.max_epoch, step,
                                                              self.num_batches) + ' ' + D_logs + ' ' + G_logs)
                if self.logger:
                    self.logger.log(loss_dict)
                # Periodically save images generated with the EMA weights,
                # then restore the live weights.
                if gen_iterations % 10000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb, words_embs, mask, image_encoder,
                                          captions, cap_lens, gen_iterations)
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' % (
                epoch, self.max_epoch, errD_total.item(), errG_total.item(), end_t - start_t))
            print('-' * 89)
            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Example #5
0
    def train(self):
        """GAN training loop with Weights & Biases logging and optional
        NVIDIA Apex mixed-precision (``cfg.APEX``) loss scaling.

        Alternately updates the discriminators and generator, keeps an EMA
        copy of the generator weights for image snapshots, and checkpoints
        every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs.
        """
        wandb.init(name=cfg.EXP_NAME, project='AttnGAN', config=cfg, dir='../logs')

        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        text_encoder, image_encoder, netG, netsD, optimizerG, optimizersD =  \
            self.apply_apex(text_encoder, image_encoder, netG, netsD, optimizerG, optimizersD)
        # add watch
        wandb.watch(netG)
        for D in netsD:
            wandb.watch(D)

        avg_param_G = copy_G_params(netG)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        # Hoisted out of the training loop: import apex once instead of on
        # every batch (the per-batch import only hit the module cache, but
        # this is clearer and avoids the repeated lookup).
        if cfg.APEX:
            from apex import amp

        log_dict = {}
        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Py3 fix: iterators have no .next() method; use next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)  # padding positions are masked out
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters (scaled when using Apex)
                    if cfg.APEX:
                        with amp.scale_loss(errD, optimizersD[i], loss_id=i) as errD_scaled:
                            errD_scaled.backward()
                    else:
                        errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    log_name = 'errD_{}'.format(i)
                    log_dict[log_name] = errD.item()

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs, G_log_dict = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                log_dict.update(G_log_dict)
                log_dict['kl_loss'] = kl_loss.item()
                # backward and update parameters (scaled when using Apex)
                if cfg.APEX:
                    with amp.scale_loss(errG_total, optimizerG, loss_id=len(netsD)) as errG_scaled:
                        errG_scaled.backward()
                else:
                    errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights; deprecated positional
                # add_(alpha, tensor) replaced with the keyword form.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                wandb.log(log_dict)
                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                    wandb.save('logs.ckpt')
                # Periodically save images generated with the EMA weights,
                # then restore the live weights.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Example #6
0
    def train(self):
        text_encoder, image_encoder, netG, netsPatD, netsShpD, netObjSSD, netObjLSD, \
            start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)

        optimizerG, optimizersPatD, optimizersShpD, optimizerObjSSD, optimizerObjLSD = \
            self.define_optimizers(netG, netsPatD, netsShpD, netObjSSD, netObjLSD)

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        match_labels = self.prepare_labels()
        clabels_emb = self.prepare_cat_emb()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            predictions = []
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = data_iter.next()
                imgs, captions, glove_captions, cap_lens, hmaps, rois, fm_rois, \
                    num_rois, bt_masks, fm_bt_masks, class_ids, keys = prepare_data(data)

                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                max_len = int(torch.max(cap_lens))
                words_embs, sent_emb = text_encoder(captions, cap_lens,
                                                    max_len)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # glove_words_embs: batch_size x 50 (glove dim) x seq_len
                glove_words_embs = self.glove_emb(glove_captions.view(-1))
                glove_words_embs = glove_words_embs.detach().view(
                    glove_captions.size(0), glove_captions.size(1), -1)
                glove_words_embs = glove_words_embs[:, :num_words].transpose(
                    1, 2)

                # clabels_feat: batch x 50 (glove dim) x max_num_roi x 1
                clabels_feat = form_clabels_feat(clabels_emb, rois[0],
                                                 num_rois)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                glb_max_num_roi = int(torch.max(num_rois))
                fake_imgs, bt_c_codes, _, _, mu, logvar = netG(
                    noise, sent_emb, words_embs, glove_words_embs,
                    clabels_feat, mask, hmaps, rois, fm_rois, num_rois,
                    bt_masks, fm_bt_masks, glb_max_num_roi)
                bt_c_codes = [bt_c_code.detach() for bt_c_code in bt_c_codes]

                #######################################################
                # (3-1) Update PatD network
                ######################################################
                errPatD_total = 0
                PatD_logs = ''
                for i in range(len(netsPatD)):
                    netsPatD[i].zero_grad()
                    errPatD = patD_loss(netsPatD[i], imgs[i], fake_imgs[i],
                                        sent_emb)
                    errPatD.backward()
                    optimizersPatD[i].step()
                    errPatD_total += errPatD
                    PatD_logs += 'errPatD%d: %.2f ' % (i, errPatD.item())

                #######################################################
                # (3-2) Update ShpD network
                ######################################################
                errShpD_total = 0
                ShpD_logs = ''
                for i in range(len(netsShpD)):
                    netsShpD[i].zero_grad()
                    hmap = hmaps[i]
                    roi = rois[i]
                    errShpD = shpD_loss(netsShpD[i], imgs[i], fake_imgs[i],
                                        hmap, roi, num_rois)
                    errShpD.backward()
                    optimizersShpD[i].step()
                    errShpD_total += errShpD
                    ShpD_logs += 'errShpD%d: %.2f ' % (i, errShpD.item())

                #######################################################
                # (3-3) Update ObjSSD network
                ######################################################
                netObjSSD.zero_grad()
                errObjSSD = objD_loss(netObjSSD, imgs[-1], fake_imgs[-1],
                                      hmaps[-1], clabels_emb, bt_c_codes[-1],
                                      rois[0], num_rois)
                if float(errObjSSD) > 0:
                    errObjSSD.backward()
                    optimizerObjSSD.step()
                    ObjSSD_logs = 'errSSACD: %.2f ' % (errObjSSD.item())

                #######################################################
                # (3-4) Update ObjLSD network
                ######################################################
                netObjLSD.zero_grad()
                errObjLSD = objD_loss(netObjLSD,
                                      imgs[-1],
                                      fake_imgs[-1],
                                      hmaps[-1],
                                      clabels_emb,
                                      bt_c_codes[-1],
                                      fm_rois,
                                      num_rois,
                                      is_large_scale=True)
                if float(errObjLSD) > 0:
                    errObjLSD.backward()
                    optimizerObjLSD.step()
                    ObjLSD_logs = 'errObjLSD: %.2f ' % (errObjLSD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                netG.zero_grad()
                errG_total, G_logs = \
                    G_loss(netsPatD, netsShpD, netObjSSD, netObjLSD, image_encoder, fake_imgs,
                                   hmaps, words_embs, sent_emb, clabels_emb, bt_c_codes[-1],
                                   match_labels, cap_lens, class_ids, rois[0], fm_rois, num_rois)

                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                #######################################################
                # (5) Print and display
                ######################################################
                images = fake_imgs[-1].detach()
                pred = self.inception_model(images)
                predictions.append(pred.data.cpu().numpy())

                step += 1
                gen_iterations += 1

                if gen_iterations % self.print_interval == 0:
                    print('[%d/%d][%d]' %
                          (epoch, self.max_epoch, gen_iterations) + '\n' +
                          PatD_logs + '\n' + ShpD_logs + '\n' + ObjSSD_logs +
                          '\n' + ObjLSD_logs + '\n' + G_logs)
                # save images
                if gen_iterations % self.display_interval == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          glove_words_embs,
                                          clabels_feat,
                                          mask,
                                          hmaps,
                                          rois,
                                          fm_rois,
                                          num_rois,
                                          bt_masks,
                                          fm_bt_masks,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)

            end_t = time.time()

            print(
                '''[%d/%d][%d]
                  Loss_PatD: %.2f Loss_ShpD: %.2f Loss_ObjSSD: %.2f Loss_ObjLSD: %.2f Loss_G: %.2f Time: %.2fs'''
                %
                (epoch, self.max_epoch, self.num_batches, errPatD_total.item(),
                 errShpD_total.item(), errObjSSD.item(), errObjLSD.item(),
                 errG_total.item(), end_t - start_t))

            predictions = np.concatenate(predictions, 0)
            mean, std = compute_inception_score(predictions,
                                                min(10, self.batch_size))
            mean_conf, std_conf = \
                negative_log_posterior_probability(predictions, min(10, self.batch_size))

            fullpath = '%s/scores_%d.txt' % (self.score_dir, epoch)
            with open(fullpath, 'w') as fp:
                fp.write('mean, std, mean_conf, std_conf \n')
                fp.write('%f, %f, %f, %f' % (mean, std, mean_conf, std_conf))

            print('inception_score: mean, std, mean_conf, std_conf')
            print('inception_score: %f, %f, %f, %f' %
                  (mean, std, mean_conf, std_conf))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsPatD, netsShpD,
                                netObjSSD, netObjLSD, epoch)

        self.save_model(netG, avg_param_G, netsPatD, netsShpD, netObjSSD,
                        netObjLSD, self.max_epoch)
    def train(self):
        """Main training loop for the classifier-augmented multi-stage GAN.

        Per step: generates fake images stage by stage, updates each
        discriminator, updates the attribute classifiers, then updates the
        generators. Keeps an exponential moving average (EMA) of generator
        weights for image dumps and snapshots, and logs losses to
        TensorBoard.
        """
        netsG, netsD, inception_model, classifiers, start_epoch = self.build_models(
        )
        # EMA copy of each stage-generator's parameters.
        avg_params_G = []
        for i in range(len(netsG)):
            avg_params_G.append(copy_G_params(netsG[i]))
        optimizersG, optimizersD, optimizersC = self.define_optimizers(
            netsG, netsD, classifiers)
        real_labels, fake_labels = self.prepare_labels()
        writer = SummaryWriter(self.args.run_dir)

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            print("epoch: {}/{}".format(epoch, self.max_epoch))
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                print("step:{}/{} {:.2f}%".format(
                    step, self.num_batches, step / self.num_batches * 100))

                ######################################################
                # (1) Prepare training data
                ######################################################
                # BUGFIX: iterator.next() is Python-2 only; use next().
                data = next(data_iter)
                real_imgs, atts, image_atts, class_ids, keys = prepare_data(
                    data)

                #######################################################
                # (2) Generate fake images, stage by stage
                ######################################################
                noise.data.normal_(0, 1)

                fake_imgs = []
                C_losses = None
                if not self.args.kl_loss:
                    if cfg.TREE.BRANCH_NUM > 0:
                        fake_img1, h_code1 = nn.parallel.data_parallel(
                            netsG[0], (noise, atts, image_atts), self.gpus)
                        fake_imgs.append(fake_img1)
                        if self.args.split == 'train':
                            # Accumulate classifier losses on both the real
                            # and the generated image of this stage.
                            att_embeddings, C_losses = classifier_loss(
                                classifiers, inception_model, real_imgs[0],
                                image_atts, C_losses)
                            _, C_losses = classifier_loss(
                                classifiers, inception_model, fake_img1,
                                image_atts, C_losses)
                        else:
                            att_embeddings, _ = classifier_loss(
                                classifiers, inception_model, fake_img1,
                                image_atts)

                    if cfg.TREE.BRANCH_NUM > 1:
                        fake_img2, h_code2 = nn.parallel.data_parallel(
                            netsG[1], (h_code1, att_embeddings), self.gpus)
                        fake_imgs.append(fake_img2)
                        if self.args.split == 'train':
                            att_embeddings, C_losses = classifier_loss(
                                classifiers, inception_model, real_imgs[1],
                                image_atts, C_losses)
                            # BUGFIX: stage 2 must score its own output
                            # (fake_img2); the original re-used fake_img1.
                            _, C_losses = classifier_loss(
                                classifiers, inception_model, fake_img2,
                                image_atts, C_losses)
                        else:
                            att_embeddings, _ = classifier_loss(
                                classifiers, inception_model, fake_img2,
                                image_atts)

                    if cfg.TREE.BRANCH_NUM > 2:
                        fake_img3 = nn.parallel.data_parallel(
                            netsG[2], (h_code2, att_embeddings), self.gpus)
                        fake_imgs.append(fake_img3)

                #######################################################
                # (3) Update D networks
                ######################################################
                errD_total = 0
                D_logs = ''
                errD_dic = {}
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], real_imgs[i],
                                              fake_imgs[i], atts, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    errD_dic['D_%d' % i] = errD.item()

                #######################################################
                # (4) Update classifiers, then G networks
                ######################################################
                step += 1
                gen_iterations += 1

                for i in range(len(netsG)):
                    netsG[i].zero_grad()
                errC_total = 0
                C_logs = ''
                for i in range(len(classifiers)):
                    classifiers[i].zero_grad()
                    # retain_graph: the generator backward below reuses the
                    # same computation graph.
                    C_losses[i].backward(retain_graph=True)
                    optimizersC[i].step()
                    errC_total += C_losses[i]
                C_logs += 'errC_total: %.2f ' % (errC_total.item())

                errG_total = 0
                errG_total, G_logs, errG_dic = \
                    generator_loss(netsD, fake_imgs, real_labels, atts, errG_total)
                if self.args.kl_loss:
                    # NOTE(review): mu/logvar are never assigned in this
                    # method (the branch producing them is disabled), so
                    # enabling kl_loss would raise NameError — confirm.
                    kl_loss = KL_loss(mu, logvar)
                    errG_total += kl_loss
                    G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                    writer.add_scalar('kl_loss', kl_loss.item(),
                                      epoch * self.num_batches + step)

                # backward and update parameters
                errG_total.backward()
                for i in range(len(optimizersG)):
                    optimizersG[i].step()
                # NOTE(review): the classifiers were already stepped above;
                # stepping optimizersC again re-applies the same (un-zeroed)
                # gradients — confirm this double update is intentional.
                for i in range(len(optimizersC)):
                    optimizersC[i].step()

                errD_dic.update(errG_dic)
                writer.add_scalars('training_losses', errD_dic,
                                   epoch * self.num_batches + step)

                # Update the EMA copies of the generator parameters.
                for i in range(len(netsG)):
                    for p, avg_p in zip(netsG[i].parameters(),
                                        avg_params_G[i]):
                        # BUGFIX: use the keyword `alpha` form; the
                        # positional add_(scalar, tensor) overload is
                        # deprecated/removed in modern PyTorch.
                        avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs + '\n' + C_logs)
                # Periodically dump images generated from the EMA weights.
                if gen_iterations % 1000 == 0:
                    backup_paras = []
                    for i in range(len(netsG)):
                        backup_para = copy_G_params(netsG[i])
                        backup_paras.append(backup_para)
                        load_params(netsG[i], avg_params_G[i])
                    # BUGFIX: `imgs` was undefined in this scope; the batch
                    # images live in `real_imgs`.
                    self.save_img_results(netsG,
                                          fixed_noise,
                                          atts,
                                          image_atts,
                                          inception_model,
                                          classifiers,
                                          real_imgs,
                                          epoch,
                                          name='average')
                    # BUGFIX: `raneg` -> `range` (NameError in original).
                    for i in range(len(netsG)):
                        load_params(netsG[i], backup_paras[i])
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Loss_C: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), errC_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netsG, avg_params_G, netsD, classifiers, epoch)

        self.save_model(netsG, avg_params_G, netsD, classifiers,
                        self.max_epoch)
Example #8
0
    def train(self):
        """Train the generator with GAN + KL (+ optional VQA) losses.

        Each step: encodes QA captions with the two-level RNN encoder,
        generates fake images, updates each discriminator, then updates the
        generator. An EMA copy of the generator weights is kept for image
        dumps; models are snapshotted every cfg.TRAIN.SNAPSHOT_INTERVAL
        epochs.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models(
        )
        H_rnn_model, L_rnn_model = text_encoder
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)

        # Optional external VQA loss (ResNet feature extractor + program
        # execution network).
        if cfg.TRAIN.EVQAL.B_EVQAL:
            netVQA_E = load_resnet_image_encoder(model_stage=2)
            netVQA = load_vqa_net(cfg.TRAIN.EVQAL.NET,
                                  load_program_vocab(
                                      cfg.TRAIN.EVQAL.PROGRAM_VOCAB_FILE),
                                  feature_dim=(512, 28, 28))
        else:
            netVQA_E = netVQA = None

        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            am_vqa_loss = AverageMeter('VQA Loss')
            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # BUGFIX: iterator.next() is Python-2 only; use next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, bbox, label_one_hot, transformation_matrices, keys, prog = self.prepare_data(
                    data)
                class_ids = None
                batch_size = captions.size(0)

                transf_matrices = transformation_matrices[0].detach()
                transf_matrices_inv = transformation_matrices[1].detach()

                per_qa_embs, avg_qa_embs, qa_nums =\
                    Level2RNNEncodeMagic(captions, cap_lens, L_rnn_model, H_rnn_model)
                # Text features are inputs only; no gradient into the encoder.
                per_qa_embs, avg_qa_embs = (per_qa_embs.detach(),
                                            avg_qa_embs.detach())

                # Attention mask: 1 marks padded QA slots beyond each
                # sample's actual number of QA pairs (qa_nums).
                # NOTE(review): uint8 masks are deprecated in modern torch
                # in favor of bool — kept as-is to preserve downstream
                # behavior; confirm before migrating.
                _nmaxqa = cfg.TEXT.MAX_QA_NUM
                mask = torch.ones(batch_size, _nmaxqa,
                                  dtype=torch.uint8).cuda()
                _ref = torch.arange(0, _nmaxqa).view(1,
                                                     -1).repeat(batch_size,
                                                                1).cuda()
                _targ = qa_nums.view(-1, 1).repeat(1, _nmaxqa)
                mask[_ref < _targ] = 0
                num_words = per_qa_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, avg_qa_embs, per_qa_embs, mask,
                          transf_matrices_inv, label_one_hot)
                fake_imgs, _, mu, logvar = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    if i == 0:  # NOTE only the first level Discriminator is modified.
                        errD = discriminator_loss(
                            netsD[i],
                            imgs[i],
                            fake_imgs[i],
                            avg_qa_embs,
                            real_labels,
                            fake_labels,
                            self.gpus,
                            local_labels=label_one_hot,
                            transf_matrices=transf_matrices,
                            transf_matrices_inv=transf_matrices_inv)
                    else:
                        errD = discriminator_loss(netsD[i], imgs[i],
                                                  fake_imgs[i], avg_qa_embs,
                                                  real_labels, fake_labels,
                                                  self.gpus)

                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   per_qa_embs, avg_qa_embs, match_labels, qa_nums, class_ids, self.gpus,
                                   local_labels=label_one_hot, transf_matrices=transf_matrices,
                                   transf_matrices_inv=transf_matrices_inv)

                # KL term only applies when the conditioning-augmentation
                # net produced (mu, logvar).
                if cfg.GAN.B_CA_NET:
                    kl_loss = KL_loss(mu, logvar)
                else:
                    kl_loss = torch.FloatTensor([0.]).squeeze().cuda()

                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()

                # Optional VQA consistency loss on the finest fake image.
                if cfg.TRAIN.EVQAL.B_EVQAL:
                    fake_img_fvqa = extract_image_feats(
                        fake_imgs[-1], netVQA_E, self.gpus)
                    errVQA = VQA_loss(netVQA, fake_img_fvqa, prog['programs'],
                                      prog['answers'], self.gpus)
                else:
                    errVQA = torch.FloatTensor([0.]).squeeze().cuda()
                G_logs += 'VQA_loss: %.2f ' % errVQA.data.item()
                beta = cfg.TRAIN.EVQAL.BETA
                errG_total += (errVQA * beta)

                # backward and update parameters
                errG_total.backward()
                optimizerG.step()

                am_vqa_loss.update(errVQA.cpu().item())
                # EMA update of generator weights.
                # BUGFIX: keyword `alpha` form; the positional
                # add_(scalar, tensor) overload is deprecated/removed.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # save images
                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                if gen_iterations % 500 == 0:  # FIXME original: 1000
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(imgs,
                                          netG,
                                          fixed_noise,
                                          avg_qa_embs,
                                          per_qa_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          transf_matrices_inv,
                                          label_one_hot,
                                          name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), end_t - start_t))
            if cfg.TRAIN.EVQAL.B_EVQAL:
                print('Avg. VQA Loss of this epoch: %s' % str(am_vqa_loss))
            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, optimizerG,
                                optimizersD, epoch)

        # BUGFIX: save the final snapshot under max_epoch (the original
        # passed the last loop index `epoch`), consistent with the other
        # trainers in this file.
        self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD,
                        self.max_epoch)
Example #9
0
    def train(self):
        """Train netG with GAN + KL losses plus a negative-DDVA term.

        The DDVA term compares images generated from human captions
        (by netG) against images a frozen target copy of the generator
        produces from imperfect (noisy) captions; the target network is
        re-synced to netG every step and kept gradient-free.
        """
        text_encoder, image_encoder, netG, target_netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # Sliding window of recent DDVA values for progress reporting.
        sliding_window = []

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # BUGFIX: iterator.next() is Python-2 only; use next().
                data = next(data_iter)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for human text ------------------------
                data_human = [captions, cap_lens, misc]

                imgs, captions, cap_lens, class_ids, keys = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Padding positions (caption token 0) are masked out.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask, cap_lens)

                # Generate images for imperfect caption text ------------
                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]

                imgs, imperfect_captions, imperfect_cap_lens, i_class_ids, imperfect_keys = prepare_data(data_imperfect)

                i_hidden = text_encoder.init_hidden(batch_size)
                i_words_embs, i_sent_emb = text_encoder(imperfect_captions, imperfect_cap_lens, i_hidden)
                i_words_embs, i_sent_emb = i_words_embs.detach(), i_sent_emb.detach()
                i_mask = (imperfect_captions == 0)
                i_num_words = i_words_embs.size(2)

                if i_mask.size(1) > i_num_words:
                    i_mask = i_mask[:, :i_num_words]

                # NOTE(review): netG above also receives cap_lens but
                # target_netG is called without it — confirm the target
                # network's signature really differs. Reuses the same noise.
                imperfect_fake_imgs, _, _, _ = target_netG(noise, i_sent_emb, i_words_embs, i_mask)

                # Sort both result sets by keys so real / fake / imperfect
                # images line up sample-for-sample.
                bag = [sent_emb, real_labels, fake_labels, words_embs, class_ids]

                keys, captions, cap_lens, fake_imgs, _, sorted_bag = sort_by_keys(keys, captions, cap_lens, fake_imgs,
                                                                                  None, bag)

                sent_emb, real_labels, fake_labels, words_embs, class_ids = \
                            sorted_bag

                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, _ = \
                            sort_by_keys(imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, None)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD, log = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    D_logs += log

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()

                # BUGFIX: this value is the generator loss; the original
                # message said "Discriminator loss".
                print('Generator loss: ', errG_total)

                # Compute and add the (scaled) negative DDVA loss --------
                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)
                neg_ddva *= 10.  # Scale so DDVA is not overwhelmed by other losses.
                errG_total += neg_ddva.to(cfg.GPU_ID)

                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights.
                # BUGFIX: keyword `alpha` form; the positional
                # add_(scalar, tensor) overload is deprecated/removed.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # BUGFIX: the original used a tab here (TabError in Py3).
                if len(sliding_window) == 100:
                    del sliding_window[0]
                sliding_window.append(neg_ddva)
                sliding_avg_ddva = sum(sliding_window) / len(sliding_window)

                print('sliding_window avg NEG DDVA: ', sliding_avg_ddva)
                print('Negative ddva: ', neg_ddva)

                # BUGFIX: the continuation line of this commented-out print
                # was left uncommented in the original — a SyntaxError.
                # if gen_iterations % 100 == 0:
                #     print('Epoch [{}/{}] Step [{}/{}]'.format(epoch, self.max_epoch, step,
                #                                               self.num_batches) + ' ' + D_logs + ' ' + G_logs)

                # Copy current weights into the frozen target network.
                load_params(target_netG, copy_G_params(netG))
                for p in target_netG.parameters():
                    p.requires_grad = False

            end_t = time.time()

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)
Example #10
0
    def train(self):
        """Run the full adversarial training loop.

        For each epoch: iterate the data loader, compute detached text
        embeddings, generate fake images from fresh noise, update every
        discriminator, then update the generator with the GAN + word/sentence
        matching + KL losses, while maintaining an exponential moving average
        (EMA) of the generator's parameters. Periodically logs losses, writes
        TensorBoard scalars, saves sample images (from the EMA weights) and
        model snapshots.

        Side effects: optimizer steps on ``netG``/``netsD``, TensorBoard
        writes via ``self.writer``, log lines via ``self.logger``, and model
        checkpoints via ``self.save_model``.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models(
        )
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = torch.FloatTensor(batch_size, nz).requires_grad_()
        # fixed_noise is reused for every qualitative sample so generated
        # images stay comparable across epochs.
        fixed_noise = (torch.FloatTensor(batch_size,
                                         nz).normal_(0, 1).requires_grad_())
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # FIX: `data_iter.next()` is Python-2-only; the builtin
                # next() works on any iterator under Python 3.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens)
                # The text encoder is frozen during GAN training, so cut the
                # graph here to avoid backpropagating into it.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                if self.audio_flag:
                    # Mask out padding positions; assumes a fixed maximum
                    # audio-caption length of 32 -- TODO confirm against the
                    # dataset/preprocessing.
                    mask = torch.ByteTensor([[0] * cap_len.item() + [1] *
                                             (32 - cap_len.item())
                                             for cap_len in cap_lens])
                    mask = mask.to(words_embs.device)
                else:
                    # Positions where the caption token is 0 (padding).
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                errD_seq = []
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_seq.append(errD)
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs, lossG_seq = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA of the generator's weights (decay 0.999).
                # FIX: `add_(0.001, p.data)` uses the deprecated
                # Tensor.add_(scalar, tensor) overload removed in recent
                # PyTorch; the keyword form is equivalent and supported.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                self.writer.add_scalars(main_tag="loss_d",
                                        tag_scalar_dict={"loss_d": errD_total},
                                        global_step=gen_iterations)
                for idx in range(len(netsD)):
                    self.writer.add_scalars(main_tag="loss_d",
                                            tag_scalar_dict={
                                                "loss_d_{:d}".format(idx):
                                                errD_seq[idx]
                                            },
                                            global_step=gen_iterations)
                self.writer.add_scalars(main_tag="loss_g",
                                        tag_scalar_dict={
                                            "loss_g": errG_total,
                                            "loss_g_gan": lossG_seq[0],
                                            "loss_g_w": lossG_seq[1],
                                            "loss_g_s": lossG_seq[2],
                                            "kl_loss": kl_loss
                                        },
                                        global_step=gen_iterations)
                if gen_iterations % 100 == 0:
                    self.logger.info(D_logs + '\n' + G_logs)
                # Save sample images from the EMA weights, then restore the
                # live training weights.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    self.logger.info("save image result...")
            end_t = time.time()

            self.logger.info('''[{}/{}][{}]
                Loss_D: {:.2f} Loss_G: {:.2f} Time: {:.2f}'''.format(
                epoch, self.max_epoch, self.num_batches, errD_total.item(),
                errG_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)