Esempio n. 1
0
    def save_model(self, netG, avg_param_G, netsD, optimG, optimsD, epoch, max_to_keep=5):
        """Save a full training checkpoint and prune older checkpoints.

        The generator is stored with its averaged weights: ``avg_param_G`` is
        loaded into ``netG`` for the save and the live weights are restored
        afterwards, even if saving raises.

        Args:
            netG: generator module.
            avg_param_G: averaged generator parameters to save in place of the
                live ones.
            netsD: sequence of discriminator modules.
            optimG: generator optimizer.
            optimsD: discriminator optimizers, indexed in parallel with ``netsD``.
            epoch: epoch number, written zero-padded to at least 4 digits as
                ``checkpoint_<epoch>.pth`` in ``self.model_dir``.
            max_to_keep: keep at most this many ``checkpoint_<epoch>.pth``
                files (oldest epochs deleted first). ``None`` or ``<= 0``
                disables pruning. Other files in the directory are never
                touched.
        """
        netDs_state_dicts = []
        optimDs_state_dicts = []
        for i, netD in enumerate(netsD):
            netDs_state_dicts.append(netD.state_dict())
            optimDs_state_dicts.append(optimsD[i].state_dict())

        backup_para = copy_G_params(netG)
        load_params(netG, avg_param_G)
        try:
            checkpoint = {
                'epoch': epoch,
                'netG': netG.state_dict(),
                'optimG': optimG.state_dict(),
                'netD': netDs_state_dicts,
                'optimD': optimDs_state_dicts}
            torch.save(checkpoint, "{}/checkpoint_{:04}.pth".format(self.model_dir, epoch))
            print('Save G/D models')
        finally:
            # Always put the live (non-averaged) weights back, even on failure.
            load_params(netG, backup_para)

        if max_to_keep is not None and max_to_keep > 0:
            prefix, suffix = 'checkpoint_', '.pth'
            # Restrict to files this method writes; a bare '*.pth' would also
            # delete unrelated weights stored in the same directory.
            pattern = os.path.join(glob.escape(self.model_dir), prefix + '*' + suffix)
            checkpoints = []
            for ckpt in glob.glob(pattern):
                epoch_str = os.path.basename(ckpt)[len(prefix):-len(suffix)]
                if epoch_str.isdecimal():  # skip hand-named files like checkpoint_best.pth
                    checkpoints.append((int(epoch_str), ckpt))
            # Numeric sort keeps chronological order past 4-digit epochs.
            checkpoints.sort()
            for _, stale in checkpoints[:-max_to_keep]:
                os.remove(stale)
Esempio n. 2
0
    def save_model(self, netG, avg_param_G, netD, epoch):
        """Save the averaged generator for ``epoch`` and the single discriminator.

        ``avg_param_G`` is swapped into ``netG`` only for the generator save;
        the live weights are put back immediately afterwards. The
        discriminator file carries no epoch number, so each call overwrites it.
        """
        live_params = copy_G_params(netG)
        load_params(netG, avg_param_G)
        generator_state = netG.state_dict()
        torch.save(generator_state, '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
        load_params(netG, live_params)

        discriminator_path = '%s/netD.pth' % (self.model_dir)
        torch.save(netD.state_dict(), discriminator_path)
        print('Save G/Ds models.')
Esempio n. 3
0
 def save_model(self, netG, avg_param_G, netINSD, netGLBD, epoch):
     """Save the averaged generator for ``epoch`` plus both discriminators.

     ``avg_param_G`` is loaded into ``netG`` only while its state dict is
     written; the live weights are restored right after. The two
     discriminator files have no epoch suffix and are overwritten each call.
     """
     saved_live = copy_G_params(netG)
     load_params(netG, avg_param_G)
     torch.save(netG.state_dict(),
                '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
     load_params(netG, saved_live)

     discriminators = (
         (netINSD, '%s/netINSD.pth'),
         (netGLBD, '%s/netGLBD.pth'),
     )
     for disc, path_template in discriminators:
         torch.save(disc.state_dict(), path_template % (self.model_dir))
Esempio n. 4
0
    def save_model(self, netDCM, avg_param_C, netD, epoch):
        """Save the averaged ``netDCM`` weights and ``netD``, both tagged with ``epoch``.

        ``avg_param_C`` is loaded into ``netDCM`` just for its save; the
        module's own parameters are restored afterwards.
        """
        snapshot = copy_G_params(netDCM)
        load_params(netDCM, avg_param_C)
        c_state = netDCM.state_dict()
        torch.save(c_state, '%s/netC_epoch_%d.pth' % (self.model_dir, epoch))
        load_params(netDCM, snapshot)

        d_path = '%s/netD_epoch_%d.pth' % (self.model_dir, epoch)
        torch.save(netD.state_dict(), d_path)

        print('Save C/D models.')
Esempio n. 5
0
 def save_model(self, netG, avg_param_G, netsD, epoch):
     """Save the averaged generator for ``epoch`` and one file per discriminator.

     ``avg_param_G`` is swapped into ``netG`` only for the generator save.
     Discriminator files are named ``netD<index>.pth`` (no epoch), so they
     are overwritten on every call.
     """
     original_params = copy_G_params(netG)
     load_params(netG, avg_param_G)
     generator_file = '%s/netG_epoch_%d.pth' % (self.model_dir, epoch)
     torch.save(netG.state_dict(), generator_file)
     load_params(netG, original_params)

     for idx, disc in enumerate(netsD):
         torch.save(disc.state_dict(), '%s/netD%d.pth' % (self.model_dir, idx))
     print('Save G/Ds models.')
Esempio n. 6
0
    def save_model(self,
                   netG,
                   avg_param_G,
                   netsD,
                   optimG,
                   optimsD,
                   epoch,
                   max_to_keep=5,
                   interval=5):
        """Save a full training checkpoint and apply a two-tier retention policy.

        The checkpoint stores the generator with its averaged weights
        (``avg_param_G`` is loaded into ``netG`` for the save and the live
        weights are restored afterwards, even if saving raises). It also
        stores the generator optimizer and the state of every discriminator
        and its optimizer.

        Args:
            netG: generator module.
            avg_param_G: averaged generator parameters to save in place of the
                live ones.
            netsD: sequence of discriminator modules.
            optimG: generator optimizer.
            optimsD: discriminator optimizers, indexed in parallel with ``netsD``.
            epoch: epoch number, written zero-padded to at least 4 digits as
                ``checkpoint_<epoch>.pth`` in ``self.model_dir``.
            max_to_keep: when a positive int, enables pruning. Of the
                checkpoints whose epoch is a multiple of ``interval``, only the
                newest ``max_to_keep`` are kept. Of all other checkpoints, only
                the newest one is kept. ``None`` or ``<= 0`` disables pruning.
            interval: epoch spacing of the checkpoints retained long-term.
        """
        netDs_state_dicts = []
        optimDs_state_dicts = []
        for i, netD in enumerate(netsD):
            netDs_state_dicts.append(netD.state_dict())
            optimDs_state_dicts.append(optimsD[i].state_dict())

        backup_para = copy_G_params(netG)
        load_params(netG, avg_param_G)
        try:
            checkpoint = {
                'epoch': epoch,
                'netG': netG.state_dict(),
                'optimG': optimG.state_dict(),
                'netD': netDs_state_dicts,
                'optimD': optimDs_state_dicts
            }
            torch.save(checkpoint,
                       "{}/checkpoint_{:04}.pth".format(self.model_dir, epoch))
            logger.info('Save G/D models')
        finally:
            # Always put the live (non-averaged) weights back, even on failure.
            load_params(netG, backup_para)

        if max_to_keep is not None and max_to_keep > 0:
            prefix, suffix = 'checkpoint_', '.pth'
            # Only look at files this method writes: slicing ``[-8:-4]`` out of
            # every '*.pth' crashed or misparsed on other files and on epochs
            # with more than 4 digits.
            pattern = os.path.join(glob.escape(self.model_dir), prefix + '*' + suffix)
            on_interval = []   # (epoch, path): newest max_to_keep survive
            off_interval = []  # (epoch, path): only the newest survives
            for ckpt in glob.glob(pattern):
                epoch_str = os.path.basename(ckpt)[len(prefix):-len(suffix)]
                if not epoch_str.isdecimal():
                    # Hand-named file (e.g. checkpoint_best.pth); never delete it.
                    continue
                ckpt_epoch = int(epoch_str)
                if ckpt_epoch % interval == 0:
                    on_interval.append((ckpt_epoch, ckpt))
                else:
                    off_interval.append((ckpt_epoch, ckpt))

            # Numeric sort keeps chronological order past 4-digit epochs.
            on_interval.sort()
            off_interval.sort()
            for _, path in on_interval[:-max_to_keep]:
                os.remove(path)
            for _, path in off_interval[:-1]:
                os.remove(path)
Esempio n. 7
0
 def save_model(self, netG, avg_param_G, netsD, epoch, save_dir=None):
     """Save the averaged generator for ``epoch`` and every discriminator.

     Args:
         netG: generator module; ``avg_param_G`` is loaded into it only for
             the save and its live weights are restored afterwards (also on
             failure).
         avg_param_G: averaged generator parameters to save.
         netsD: sequence of discriminators, saved as ``netD<index>.pth`` in
             ``self.model_dir`` (no epoch suffix, overwritten each call).
         epoch: epoch number used in the generator file name.
         save_dir: directory for ``netG_epoch_<epoch>.pth``. Defaults to
             ``self.model_dir``. Previously this was a hard-coded,
             machine-specific absolute path, which failed elsewhere and split
             G and D files across directories; pass it explicitly to keep
             writing there.
     """
     if save_dir is None:
         save_dir = self.model_dir
     backup_para = copy_G_params(netG)
     load_params(netG, avg_param_G)
     try:
         torch.save(netG.state_dict(),
                    '%s/netG_epoch_%d.pth' % (save_dir, epoch))
     finally:
         load_params(netG, backup_para)

     for i, netD in enumerate(netsD):
         torch.save(netD.state_dict(),
                    '%s/netD%d.pth' % (self.model_dir, i))
     print('Save G/Ds models.')
Esempio n. 8
0
    def save_model(self, netG, avg_param_G, image_encoder, text_encoder, netsD,
                   epoch, cap_model, optimizerC, optimizerI, optimizerT,
                   lr_schedulerC, lr_schedulerI, lr_schedulerT, optimizerG,
                   optimizersD):
        """Save the GAN, the caption model and both encoders with optimizer state.

        Files written to ``self.model_dir``:
          * ``netG_epoch_<epoch>.pth``: averaged generator weights
            (``avg_param_G``) plus ``optimizerG`` state.
          * ``netD<i>.pth``: each discriminator plus its optimizer state (no
            epoch suffix, overwritten each call).
          * ``cap_model<epoch>.pth``, ``image_encoder<epoch>.pth`` and
            ``text_encoder<epoch>.pth``: weights, optimizer state and
            LR-scheduler state.

        GAN files are ``{'model', 'optimizer'}`` dicts; the other three also
        carry ``'lr_scheduler'``. ``netG``'s live weights are restored after
        its save, even if the save raises.
        """
        backup_para = copy_G_params(netG)
        load_params(netG, avg_param_G)
        try:
            # torch.save(obj, f) takes the object and the destination as two
            # separate arguments; packing both into one tuple raises TypeError.
            torch.save({
                'model': netG.state_dict(),
                'optimizer': optimizerG.state_dict()
            }, '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
        finally:
            load_params(netG, backup_para)

        for i, netD in enumerate(netsD):
            torch.save({
                'model': netD.state_dict(),
                'optimizer': optimizersD[i].state_dict()
            }, '%s/netD%d.pth' % (self.model_dir, i))

        print('Save G/Ds models.')

        # Caption model, image encoder, text encoder: same payload layout each.
        for module, optimizer, scheduler, path_template in (
                (cap_model, optimizerC, lr_schedulerC, '%s/cap_model%d.pth'),
                (image_encoder, optimizerI, lr_schedulerI, '%s/image_encoder%d.pth'),
                (text_encoder, optimizerT, lr_schedulerT, '%s/text_encoder%d.pth')):
            torch.save(
                {
                    'model': module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': scheduler.state_dict(),
                }, path_template % (self.model_dir, epoch))
Esempio n. 9
0
 def save_model(self, netG, avg_param_G, netsD, epoch,
                mirror_dir='/content/drive/My Drive/cubModelGAN'):
     """Save the averaged generator and all discriminators, optionally mirrored.

     Every file is written to ``self.model_dir`` and then, unless
     ``mirror_dir`` is ``None``, a second copy is written to ``mirror_dir``.
     The default is the path that was previously hard-coded.

     Args:
         netG: generator module; ``avg_param_G`` is loaded into it only for
             the save and its live weights are restored afterwards (also on
             failure).
         avg_param_G: averaged generator parameters to save.
         netsD: sequence of discriminators, saved as ``netD<index>.pth`` (no
             epoch suffix, overwritten each call).
         epoch: epoch number used in the generator file name.
         mirror_dir: second destination directory, or ``None`` to skip it.
     """
     target_dirs = [self.model_dir]
     if mirror_dir is not None:
         target_dirs.append(mirror_dir)

     backup_para = copy_G_params(netG)
     load_params(netG, avg_param_G)
     try:
         netG_state = netG.state_dict()
         for target_dir in target_dirs:
             torch.save(netG_state,
                        '%s/netG_epoch_%d.pth' % (target_dir, epoch))
     finally:
         load_params(netG, backup_para)

     for i, netD in enumerate(netsD):
         netD_state = netD.state_dict()
         for target_dir in target_dirs:
             torch.save(netD_state, '%s/netD%d.pth' % (target_dir, i))
     print('Save G/Ds models.')
Esempio n. 10
0
 def save_model(self, netG, avg_param_G, netsD, epoch):
     """Save the averaged generator for ``epoch`` and each discriminator.

     When ``cfg.GAN.B_STYLEGEN`` is set, the generator file is a dict holding
     ``netG.w_ewma`` and the state dict; otherwise it is the bare state dict.
     ``avg_param_G`` is loaded into ``netG`` only for this save.
     Discriminators go to ``netD<index>.pth`` (overwritten each call).
     """
     saved = copy_G_params(netG)
     load_params(netG, avg_param_G)
     _pathG = '%s/netG_epoch_%d.pth' % (self.model_dir, epoch)
     if not cfg.GAN.B_STYLEGEN:
         payload = netG.state_dict()
     else:
         payload = {
             'w_ewma': netG.w_ewma,
             'netG_state_dict': netG.state_dict(),
         }
     torch.save(payload, _pathG)
     load_params(netG, saved)

     for stage, disc in enumerate(netsD):
         torch.save(disc.state_dict(), '%s/netD%d.pth' % (self.model_dir, stage))
     print('Save G/Ds models.')
Esempio n. 11
0
    def save_model(self, netG, avg_param_G, netsD, epoch, text_encoder,
                   image_encoder):
        """Save the averaged generator, the discriminators and (optionally) the encoders.

        ``avg_param_G`` is swapped into ``netG`` only while its state dict is
        written. Discriminators go to ``netD<index>.pth`` (no epoch suffix).
        The encoders are saved only when ``cfg.ANNO_PATH`` is truthy.
        """
        live_weights = copy_G_params(netG)
        load_params(netG, avg_param_G)
        torch.save(netG.state_dict(),
                   '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
        load_params(netG, live_weights)

        for d_idx, d_net in enumerate(netsD):
            torch.save(d_net.state_dict(), '%s/netD%d.pth' % (self.model_dir, d_idx))
        print('Save G/Ds models.')

        if not cfg.ANNO_PATH:
            return
        # NOTE(review): the encoders are pickled as whole module objects, not as
        # state dicts like the networks above; loading code must match.
        torch.save(text_encoder, '%s/text_encoder_%d.pth' % (self.model_dir, epoch))
        torch.save(image_encoder, '%s/image_encoder_%d.pth' % (self.model_dir, epoch))
        print('Save text/image encoder')
    def save_model(self, netsG, avg_params_G, netsD, classifiers, epoch):
        """Save every classifier, generator and discriminator into ``self.model_dir``.

        ``mkdir_p`` (a project helper, presumably ``mkdir -p``) is called on
        the directory first. Each generator ``netsG[k]`` is saved with the
        averaged parameters ``avg_params_G[k]`` and then restored to its live
        weights. Only generator files carry the epoch number; classifier and
        discriminator files are overwritten each call.
        """
        mkdir_p(self.model_dir)

        for k, clf in enumerate(classifiers):
            torch.save(clf.state_dict(), '%s/classifier_%d.pth' % (self.model_dir, k))

        for k, gen in enumerate(netsG):
            live = copy_G_params(gen)
            load_params(gen, avg_params_G[k])
            torch.save(gen.state_dict(),
                       '%s/netsG%d_epoch_%d.pth' % (self.model_dir, k, epoch))
            load_params(gen, live)

        for k, disc in enumerate(netsD):
            torch.save(disc.state_dict(), '%s/netD%d.pth' % (self.model_dir, k))

        print('Save Gs/Ds/classifiers models.')
Esempio n. 13
0
    def train(self):
        """Run the adversarial training loop from ``start_epoch`` to ``self.max_epoch``.

        Each step:
          1. Encode the captions (outputs are detached).
          2. Generate fake images with ``netG``.
          3. Update every discriminator in ``netsD``.
          4. Update the generator on ``generator_loss`` + KL loss, then
             refresh the exponential moving average ``avg_param_G``
             (decay 0.999).

        Side effects:
          * ``self.logger.log`` receives the loss dict every step when a
            logger is set.
          * Progress is printed every 100 generator iterations.
          * Sample images are written with the averaged weights every 10000
            iterations.
          * ``self.save_model`` runs every ``cfg.TRAIN.SNAPSHOT_INTERVAL``
            epochs and once at the end.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        loss_dict = {}
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Builtin next(): recent PyTorch DataLoader iterators no longer
                # provide the Python-2 style ``.next()`` method.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask positions whose token id is 0 (presumably padding),
                # trimmed to the number of words the encoder produced.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask, cap_lens)

                #######################################################
                # (3) Update D network
                ######################################################
                # Kept as a Python float: the total is only printed, and
                # summing loss tensors would hold references to their graphs.
                errD_total = 0.0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD, log, d_dict = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                           sent_emb, real_labels, fake_labels)
                    errD.backward()
                    optimizersD[i].step()
                    errD_value = errD.item()
                    errD_total += errD_value
                    D_logs += 'errD%d: %.2f ' % (i, errD_value)
                    D_logs += log
                    loss_dict['Real_Acc_{}'.format(i)] = d_dict['Real_Acc']
                    loss_dict['Fake_Acc_{}'.format(i)] = d_dict['Fake_Acc']
                    loss_dict['errD_{}'.format(i)] = errD_value

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs, g_dict = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                loss_dict.update(g_dict)
                loss_dict['kl_loss'] = kl_loss.item()
                errG_total.backward()
                optimizerG.step()
                # EMA of G's weights (decay 0.999); these averaged weights are
                # the ones rendered below and passed to save_model.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print('Epoch [{}/{}] Step [{}/{}]'.format(epoch, self.max_epoch, step,
                                                              self.num_batches) + ' ' + D_logs + ' ' + G_logs)
                if self.logger:
                    self.logger.log(loss_dict)
                # Periodically render samples with the averaged weights.
                if gen_iterations % 10000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    try:
                        self.save_img_results(netG, fixed_noise, sent_emb, words_embs, mask, image_encoder,
                                              captions, cap_lens, gen_iterations)
                    finally:
                        # Restore the live weights even if rendering fails.
                        load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' % (
                epoch, self.max_epoch, errD_total, errG_total.item(), end_t - start_t))
            print('-' * 89)
            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Esempio n. 14
0
    def train(self):
        """Run the adversarial training loop with wandb logging and optional apex AMP.

        Each step:
          1. Encode the captions (outputs are detached).
          2. Generate fake images.
          3. Update every discriminator in ``netsD``.
          4. Update the generator on ``generator_loss`` + KL loss, then
             refresh the exponential moving average ``avg_param_G``
             (decay 0.999).

        With ``cfg.APEX``, backward passes go through ``amp.scale_loss``, with
        loss id ``i`` for discriminator ``i`` and ``len(netsD)`` for the
        generator.

        Side effects:
          * ``wandb.log`` is called every step.
          * Every 100 generator iterations, progress is printed and
            ``wandb.save('logs.ckpt')`` is called.
          * Sample images are written with the averaged weights every 1000
            iterations.
          * ``self.save_model`` runs every ``cfg.TRAIN.SNAPSHOT_INTERVAL``
            epochs and once at the end.
        """
        wandb.init(name=cfg.EXP_NAME, project='AttnGAN', config=cfg, dir='../logs')

        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        text_encoder, image_encoder, netG, netsD, optimizerG, optimizersD =  \
            self.apply_apex(text_encoder, image_encoder, netG, netsD, optimizerG, optimizersD)
        # Register the networks with wandb for gradient/parameter tracking.
        wandb.watch(netG)
        for D in netsD:
            wandb.watch(D)

        avg_param_G = copy_G_params(netG)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        if cfg.APEX:
            # Import once here rather than on every backward pass.
            from apex import amp

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        log_dict = {}
        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Builtin next(): recent PyTorch DataLoader iterators no longer
                # provide the Python-2 style ``.next()`` method.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask positions whose token id is 0 (presumably padding),
                # trimmed to the number of words the encoder produced.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                # Kept as a Python float: the total is only printed, and
                # summing loss tensors would hold references to their graphs.
                errD_total = 0.0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    if cfg.APEX:
                        with amp.scale_loss(errD, optimizersD[i], loss_id=i) as errD_scaled:
                            errD_scaled.backward()
                    else:
                        errD.backward()
                    optimizersD[i].step()
                    errD_value = errD.item()
                    errD_total += errD_value
                    D_logs += 'errD%d: %.2f ' % (i, errD_value)
                    log_name = 'errD_{}'.format(i)
                    log_dict[log_name] = errD_value

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs, G_log_dict = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                log_dict.update(G_log_dict)
                log_dict['kl_loss'] = kl_loss.item()
                if cfg.APEX:
                    with amp.scale_loss(errG_total, optimizerG, loss_id=len(netsD)) as errG_scaled:
                        errG_scaled.backward()
                else:
                    errG_total.backward()
                optimizerG.step()
                # EMA of G's weights (decay 0.999); these averaged weights are
                # the ones rendered below and passed to save_model.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                wandb.log(log_dict)
                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                    wandb.save('logs.ckpt')
                # Periodically render samples with the averaged weights.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    try:
                        self.save_img_results(netG, fixed_noise, sent_emb,
                                              words_embs, mask, image_encoder,
                                              captions, cap_lens, epoch, name='average')
                    finally:
                        # Restore the live weights even if rendering fails.
                        load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total, errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Esempio n. 15
0
    def train(self):
        """Train the generator against patch, shape and object discriminators.

        Each step:
          1. Encode the captions (outputs detached), then build GloVe word
             embeddings and per-ROI class-label features.
          2. Generate fake images.
          3. Update each patch discriminator (``netsPatD``) and each shape
             discriminator (``netsShpD``).
          4. Update ``netObjSSD`` and ``netObjLSD``; each update is skipped
             when its loss is not positive.
          5. Update the generator on ``G_loss`` + KL loss, followed by the
             EMA update of ``avg_param_G`` (decay 0.999).

        Every epoch, predictions of ``self.inception_model`` on the
        last-stage fake images are scored with ``compute_inception_score``
        and ``negative_log_posterior_probability``. The scores are written to
        ``self.score_dir/scores_<epoch>.txt``. Checkpoints are saved every
        ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs and once at the end.
        """
        text_encoder, image_encoder, netG, netsPatD, netsShpD, netObjSSD, netObjLSD, \
            start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)

        optimizerG, optimizersPatD, optimizersShpD, optimizerObjSSD, optimizerObjLSD = \
            self.define_optimizers(netG, netsPatD, netsShpD, netObjSSD, netObjLSD)

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        match_labels = self.prepare_labels()
        clabels_emb = self.prepare_cat_emb()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            predictions = []
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Builtin next(): recent PyTorch DataLoader iterators no longer
                # provide the Python-2 style ``.next()`` method.
                data = next(data_iter)
                imgs, captions, glove_captions, cap_lens, hmaps, rois, fm_rois, \
                    num_rois, bt_masks, fm_bt_masks, class_ids, keys = prepare_data(data)

                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                max_len = int(torch.max(cap_lens))
                words_embs, sent_emb = text_encoder(captions, cap_lens,
                                                    max_len)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # glove_words_embs: batch_size x 50 (glove dim) x seq_len
                glove_words_embs = self.glove_emb(glove_captions.view(-1))
                glove_words_embs = glove_words_embs.detach().view(
                    glove_captions.size(0), glove_captions.size(1), -1)
                glove_words_embs = glove_words_embs[:, :num_words].transpose(
                    1, 2)

                # clabels_feat: batch x 50 (glove dim) x max_num_roi x 1
                clabels_feat = form_clabels_feat(clabels_emb, rois[0],
                                                 num_rois)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                glb_max_num_roi = int(torch.max(num_rois))
                fake_imgs, bt_c_codes, _, _, mu, logvar = netG(
                    noise, sent_emb, words_embs, glove_words_embs,
                    clabels_feat, mask, hmaps, rois, fm_rois, num_rois,
                    bt_masks, fm_bt_masks, glb_max_num_roi)
                bt_c_codes = [bt_c_code.detach() for bt_c_code in bt_c_codes]

                # Per-step loss totals below are Python floats: they are only
                # printed, and summing tensors would hold their graphs.
                #######################################################
                # (3-1) Update PatD network
                ######################################################
                errPatD_total = 0.0
                PatD_logs = ''
                for i in range(len(netsPatD)):
                    netsPatD[i].zero_grad()
                    errPatD = patD_loss(netsPatD[i], imgs[i], fake_imgs[i],
                                        sent_emb)
                    errPatD.backward()
                    optimizersPatD[i].step()
                    errPatD_value = errPatD.item()
                    errPatD_total += errPatD_value
                    PatD_logs += 'errPatD%d: %.2f ' % (i, errPatD_value)

                #######################################################
                # (3-2) Update ShpD network
                ######################################################
                errShpD_total = 0.0
                ShpD_logs = ''
                for i in range(len(netsShpD)):
                    netsShpD[i].zero_grad()
                    hmap = hmaps[i]
                    roi = rois[i]
                    errShpD = shpD_loss(netsShpD[i], imgs[i], fake_imgs[i],
                                        hmap, roi, num_rois)
                    errShpD.backward()
                    optimizersShpD[i].step()
                    errShpD_value = errShpD.item()
                    errShpD_total += errShpD_value
                    ShpD_logs += 'errShpD%d: %.2f ' % (i, errShpD_value)

                #######################################################
                # (3-3) Update ObjSSD network
                ######################################################
                # The update and its log entry are skipped when the loss is
                # not positive. Reset the log every step so a skipped step
                # neither raises NameError at print time nor repeats a stale value.
                ObjSSD_logs = ''
                netObjSSD.zero_grad()
                errObjSSD = objD_loss(netObjSSD, imgs[-1], fake_imgs[-1],
                                      hmaps[-1], clabels_emb, bt_c_codes[-1],
                                      rois[0], num_rois)
                if float(errObjSSD) > 0:
                    errObjSSD.backward()
                    optimizerObjSSD.step()
                    ObjSSD_logs = 'errSSACD: %.2f ' % (errObjSSD.item())

                #######################################################
                # (3-4) Update ObjLSD network
                ######################################################
                ObjLSD_logs = ''
                netObjLSD.zero_grad()
                errObjLSD = objD_loss(netObjLSD,
                                      imgs[-1],
                                      fake_imgs[-1],
                                      hmaps[-1],
                                      clabels_emb,
                                      bt_c_codes[-1],
                                      fm_rois,
                                      num_rois,
                                      is_large_scale=True)
                if float(errObjLSD) > 0:
                    errObjLSD.backward()
                    optimizerObjLSD.step()
                    ObjLSD_logs = 'errObjLSD: %.2f ' % (errObjLSD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                netG.zero_grad()
                errG_total, G_logs = \
                    G_loss(netsPatD, netsShpD, netObjSSD, netObjLSD, image_encoder, fake_imgs,
                           hmaps, words_embs, sent_emb, clabels_emb, bt_c_codes[-1],
                           match_labels, cap_lens, class_ids, rois[0], fm_rois, num_rois)

                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                errG_total.backward()
                optimizerG.step()
                # EMA of G's weights (decay 0.999), used for sample images
                # and checkpoints.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                #######################################################
                # (5) Print and display
                ######################################################
                # Inference-only forward pass: no autograd graph is needed.
                with torch.no_grad():
                    pred = self.inception_model(fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                step += 1
                gen_iterations += 1

                if gen_iterations % self.print_interval == 0:
                    print('[%d/%d][%d]' %
                          (epoch, self.max_epoch, gen_iterations) + '\n' +
                          PatD_logs + '\n' + ShpD_logs + '\n' + ObjSSD_logs +
                          '\n' + ObjLSD_logs + '\n' + G_logs)
                # Periodically render samples with the averaged weights.
                if gen_iterations % self.display_interval == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    try:
                        self.save_img_results(netG,
                                              fixed_noise,
                                              sent_emb,
                                              words_embs,
                                              glove_words_embs,
                                              clabels_feat,
                                              mask,
                                              hmaps,
                                              rois,
                                              fm_rois,
                                              num_rois,
                                              bt_masks,
                                              fm_bt_masks,
                                              image_encoder,
                                              captions,
                                              cap_lens,
                                              epoch,
                                              name='average')
                    finally:
                        # Restore the live weights even if rendering fails.
                        load_params(netG, backup_para)

            end_t = time.time()

            # float() accepts a one-element tensor or a plain number;
            # objD_loss's return type is not visible here.
            print(
                '''[%d/%d][%d]
                  Loss_PatD: %.2f Loss_ShpD: %.2f Loss_ObjSSD: %.2f Loss_ObjLSD: %.2f Loss_G: %.2f Time: %.2fs'''
                %
                (epoch, self.max_epoch, self.num_batches, errPatD_total,
                 errShpD_total, float(errObjSSD), float(errObjLSD),
                 errG_total.item(), end_t - start_t))

            predictions = np.concatenate(predictions, 0)
            mean, std = compute_inception_score(predictions,
                                                min(10, self.batch_size))
            mean_conf, std_conf = \
                negative_log_posterior_probability(predictions, min(10, self.batch_size))

            fullpath = '%s/scores_%d.txt' % (self.score_dir, epoch)
            with open(fullpath, 'w') as fp:
                fp.write('mean, std, mean_conf, std_conf \n')
                fp.write('%f, %f, %f, %f' % (mean, std, mean_conf, std_conf))

            print('inception_score: mean, std, mean_conf, std_conf')
            print('inception_score: %f, %f, %f, %f' %
                  (mean, std, mean_conf, std_conf))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsPatD, netsShpD,
                                netObjSSD, netObjLSD, epoch)

        self.save_model(netG, avg_param_G, netsPatD, netsShpD, netObjSSD,
                        netObjLSD, self.max_epoch)
Esempio n. 16
0
    def train(self):
        """Run the GAN training loop, logging losses to a TensorBoard writer.

        For every batch: compute (detached) word/sentence embeddings with
        ``text_encoder``, generate fake images with ``netG``, update each
        discriminator in ``netsD`` once, then update the generator and the
        exponential moving average of its weights kept in ``avg_param_G``.

        Preview images are rendered with the averaged weights every 1000
        generator iterations. Checkpoints are saved every
        ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs and once more after the last
        epoch. Per-iteration D losses and per-epoch D/G losses go to a
        ``SummaryWriter`` rooted at ``runs/architecture``. All scalars are
        exported to ``./all_scalars.json`` at the end.
        """
        writer = SummaryWriter('runs/architecture')
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fixed noise keeps preview images comparable across snapshots.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        # NOTE: the counter restarts at 0 even when resuming from start_epoch > 0.
        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            print("=================================START TRAINING========================================")
            print("++++++++++++++++++++++++++++++++++%d+++++++++++++++++++++++++++++++++++++++++++++++++++\n" % gen_iterations)
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # Builtin next(): recent PyTorch DataLoader iterators no
                # longer provide the Python-2 style .next() method.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Detached: no gradient flows back into the text encoder.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # True where the token id is 0, truncated to the number of
                # word embeddings actually produced.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                ######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                ######################################################
                # (3) Update D networks (netsD[i] judges fake_imgs[i])
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.data)

                    writer.add_scalar('data/errD%d' % i, errD.data.item(), gen_iterations)

                ######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data
                errG_total.backward()
                optimizerG.step()
                # EMA of the generator weights (decay 0.999). Uses
                # add_(tensor, alpha=...) instead of the deprecated
                # add_(scalar, tensor) overload.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # Save previews with the averaged weights, then restore the
                # current weights so training continues unaffected.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)

            end_t = time.time()
            # Epoch-level scalars use the losses of the last batch only.
            writer.add_scalar('data/Loss_D', errD_total.data.item(), epoch)
            writer.add_scalar('data/Loss_G', errG_total.data.item(), epoch)
            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data, errG_total.data,
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
        # NOTE(review): export_scalars_to_json is a tensorboardX SummaryWriter
        # method; confirm which SummaryWriter this file imports.
        writer.export_scalars_to_json("./all_scalars.json")
        writer.close()
Esempio n. 17
0
    def train(self):
        """Train G and the discriminators, feeding mismatched captions to D.

        Each batch:
        - encode the matching captions and the "wrong" (mismatched) captions
          with ``text_encoder`` (both detached);
        - generate fake images and update every discriminator once;
        - update the generator (``generator_loss`` plus the KL term) and the
          EMA copy of its weights in ``avg_param_G``.

        Previews are saved with the averaged weights every 1000 generator
        iterations. Checkpoints are saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL``
        epochs and once after the final epoch.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch, style_loss = \
            self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # Builtin next(): recent PyTorch DataLoader iterators have no .next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # Word embeddings of the mismatched captions (only the word
                # embeddings are consumed, by discriminator_loss).
                w_words_embs, _ = text_encoder(wrong_caps, wrong_caps_len, hidden)
                w_words_embs = w_words_embs.detach()

                # True where the token id is 0, truncated to the number of words.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels, words_embs,
                                              cap_lens, image_encoder,
                                              class_ids, w_words_embs,
                                              wrong_caps_len, wrong_cls_id)
                    # retain_graph=True as in the original; presumably parts
                    # of the graph are reused by later backward passes.
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD)

                step += 1
                gen_iterations += 1

                netG.zero_grad()
                # generator_loss also returns separate word/sentence losses;
                # they are already included in errG_total / G_logs here.
                errG_total, G_logs, _, _ = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens,
                                   class_ids, style_loss, imgs)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss

                errG_total.backward()
                optimizerG.step()
                # EMA of the generator weights (decay 0.999); alpha= form
                # replaces the deprecated add_(scalar, tensor) overload.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)

                # Save previews with the averaged weights, then restore.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Esempio n. 18
0
    def train(self):
        """Train G/D with an added negative-DDVA loss against a lagged target G.

        Each batch carries human captions and "imperfect" captions that share
        the same ``misc`` payload (presumably the same images — confirm).
        - Human captions drive ``netG``.
        - Imperfect captions drive ``target_netG`` on ``secondary_device``,
          using the same noise values. ``target_netG`` receives a copy of
          ``netG``'s weights every 20 generator iterations.
        - Both result sets are re-ordered by sample key so they line up, and
          the discriminators are updated on the human-caption fakes.
        - The generator loss is augmented with
          ``10 * negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)``.

        Checkpoints are saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs
        and once after the final epoch.
        """
        text_encoder, image_encoder, netG, target_netG, netsD, start_epoch, style_loss = \
            self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # Builtin next(): recent PyTorch DataLoader iterators have no .next().
                data = next(data_iter)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for the human-written captions ----------
                data_human = [captions, cap_lens, misc]

                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # Word embeddings of the mismatched ("wrong") captions.
                w_words_embs, _ = text_encoder(wrong_caps, wrong_caps_len, hidden)
                w_words_embs = w_words_embs.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                # Generate images for the imperfect captions --------------
                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]

                # NOTE: this rebinds `imgs`. The real images used below come
                # from the imperfect-caption batch, re-ordered by key.
                imgs, imperfect_captions, imperfect_cap_lens, i_class_ids, imperfect_keys, i_wrong_caps, \
                    i_wrong_caps_len, i_wrong_cls_id = prepare_data(data_imperfect)

                i_hidden = text_encoder.init_hidden(batch_size)
                i_words_embs, i_sent_emb = text_encoder(
                    imperfect_captions, imperfect_cap_lens, i_hidden)
                i_words_embs, i_sent_emb = i_words_embs.detach(), i_sent_emb.detach()
                i_mask = (imperfect_captions == 0)
                i_num_words = i_words_embs.size(2)

                if i_mask.size(1) > i_num_words:
                    i_mask = i_mask[:, :i_num_words]

                # Move the target-network inputs to the secondary device.
                # IMPORTANT: reuse the same noise values as netG, but keep
                # the moved copy in its own variable. Rebinding `noise`
                # itself left it on secondary_device, so from the second
                # iteration on netG was fed noise on the wrong device
                # whenever secondary_device differs from netG's device.
                # NOTE(review): `secondary_device` is not defined in this
                # method; presumably a module-level global — confirm.
                i_noise = noise.to(secondary_device)
                i_sent_emb = i_sent_emb.to(secondary_device)
                i_words_embs = i_words_embs.to(secondary_device)
                i_mask = i_mask.to(secondary_device)

                imperfect_fake_imgs, _, _, _ = target_netG(
                    i_noise, i_sent_emb, i_words_embs, i_mask)

                # Sort both result sets by sample key so they are aligned --
                bag = [
                    sent_emb, real_labels, fake_labels, words_embs, class_ids,
                    w_words_embs, wrong_caps_len, wrong_cls_id
                ]

                keys, captions, cap_lens, fake_imgs, _, sorted_bag = sort_by_keys(
                    keys, captions, cap_lens, fake_imgs, None, bag)

                sent_emb, real_labels, fake_labels, words_embs, class_ids, \
                    w_words_embs, wrong_caps_len, wrong_cls_id = sorted_bag

                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, _ = \
                    sort_by_keys(imperfect_keys, imperfect_captions,
                                 imperfect_cap_lens, imperfect_fake_imgs, imgs,
                                 None)

                # Update the discriminators --------------------------------
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels, words_embs,
                                              cap_lens, image_encoder,
                                              class_ids, w_words_embs,
                                              wrong_caps_len, wrong_cls_id)
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD)

                step += 1
                gen_iterations += 1

                # Update the generator -------------------------------------
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens,
                                   class_ids, style_loss, imgs)
                kl_loss = KL_loss(mu, logvar)

                errG_total += kl_loss

                G_logs += 'kl_loss: %.2f ' % kl_loss

                # Move real and fake images to the secondary device for DDVA.
                for i in range(len(imgs)):
                    imgs[i] = imgs[i].to(secondary_device)
                    fake_imgs[i] = fake_imgs[i].to(secondary_device)

                # Compute and add the DDVA loss ----------------------------
                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)
                neg_ddva *= 10.  # Scale so the DDVA term is not overwhelmed by the other losses.
                errG_total += neg_ddva.to(cfg.GPU_ID)
                G_logs += 'negative_ddva_loss: %.2f ' % neg_ddva

                errG_total.backward()

                optimizerG.step()
                # EMA of the generator weights (decay 0.999); alpha= form
                # replaces the deprecated add_(scalar, tensor) overload.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)

                # Copy the current (non-averaged) weights to the target network.
                if gen_iterations % 20 == 0:
                    load_params(target_netG, copy_G_params(netG))

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f neg_ddva: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total, neg_ddva, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Esempio n. 19
0
    def train(self):
        """Train G to modify input images according to a caption.

        Each batch provides input images, target output images and captions
        (via ``prepare_data_LGIE``).
        - The caption embeddings are NOT detached, so generator-loss gradients
          reach ``text_encoder``'s graph. Whether the text encoder is actually
          updated depends on ``define_optimizers`` (not shown here).
        - Discriminators are trained only when ``cfg.TRAIN.W_GAN`` is set.
        - The generator loss gets the target outputs and the ``VGG`` model,
          plus a KL term weighted by ``cfg.TRAIN.W_KL``.

        Every 500 generator iterations a visualization grid is written to
        ``self.image_dir/G_jwtvis_<iteration>.png``. Checkpoints (including
        both encoders) are saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs
        and once after the final epoch.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch, VGG = \
            self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # Builtin next(): recent PyTorch DataLoader iterators have no .next().
                data = next(data_iter)
                input_imgs_list, output_imgs_list, captions, cap_lens = prepare_data_LGIE(
                    data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                # Matched text embeddings. Do NOT detach: they are meant to
                # be trained along with the generator.
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)

                # Regional and global image features of the last-branch
                # input image (presumably the highest resolution).
                region_features, cnn_code = image_encoder(
                    input_imgs_list[cfg.TREE.BRANCH_NUM - 1])

                # True where the token id is 0, truncated to the number of words.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                ######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar, _, _ = netG(noise, sent_emb,
                                                      words_embs, mask,
                                                      cnn_code,
                                                      region_features)

                ######################################################
                # (3) Update D networks (only when cfg.TRAIN.W_GAN is set)
                ######################################################
                errD_total = 0
                D_logs = ''
                if cfg.TRAIN.W_GAN:
                    for i in range(len(netsD)):
                        netsD[i].zero_grad()
                        errD = discriminator_loss(netsD[i], input_imgs_list[i],
                                                  fake_imgs[i], sent_emb,
                                                  real_labels, fake_labels)
                        errD.backward(retain_graph=True)
                        optimizersD[i].step()
                        errD_total += errD
                        D_logs += 'errD%d: %.2f ' % (i, errD)

                ######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = generator_loss(netsD, image_encoder,
                                                    fake_imgs, real_labels,
                                                    words_embs, sent_emb, None,
                                                    None, None, VGG,
                                                    output_imgs_list)
                kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.W_KL
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss
                errG_total.backward()
                optimizerG.step()
                # EMA of the generator weights (decay 0.999); alpha= form
                # replaces the deprecated add_(scalar, tensor) overload.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    # Previously G losses were only printed with W_GAN on.
                    if cfg.TRAIN.W_GAN:
                        print(D_logs + '\n' + G_logs)
                    else:
                        print(G_logs)

                # Visualization grid: one row per sample showing
                # [input | target output | generated], then a caption strip.
                if gen_iterations % 500 == 0:
                    # NOTE(review): averaged weights are loaded and restored,
                    # but netG is not run in between (save_img_results is
                    # disabled), so the grid shows fake_imgs produced with
                    # the current weights.
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)

                    input_img = self.tensor_to_numpy(input_imgs_list[-1])
                    output_img = self.tensor_to_numpy(output_imgs_list[-1])
                    fake_img = self.tensor_to_numpy(fake_imgs[-1])
                    # Arrays are (b x h x w x c); derive the size instead of
                    # assuming 256x256.
                    img_h, img_w = fake_img.shape[1], fake_img.shape[2]
                    # Never index past the batch (at most 5 rows).
                    nvis = min(5, len(fake_img))
                    gap = 50
                    text_bg = np.zeros((gap, img_w * 3, 3))
                    res = np.zeros((1, img_w * 3, 3))
                    for vis_idx in range(nvis):
                        row = np.concatenate(
                            [input_img[vis_idx], output_img[vis_idx],
                             fake_img[vis_idx]],
                            1)  # (h, w * 3, 3)
                        row = np.concatenate([row, text_bg],
                                             0)  # (h + gap, w * 3, 3)

                        # Decode caption tokens up to the first 0 token.
                        cur_cap = captions[vis_idx].data.cpu().numpy()
                        sentence = []
                        for token in cur_cap:
                            if token == 0:
                                break
                            sentence.append(self.ixtoword[token].encode(
                                'ascii', 'ignore').decode('ascii'))
                        cv2.putText(row, ' '.join(sentence), (40, img_h + 10),
                                    cv2.FONT_HERSHEY_PLAIN, 1.2, (0, 0, 255),
                                    1)
                        res = np.concatenate([res, row], 0)

                    cv2.imwrite(
                        os.path.join(self.image_dir,
                                     f'G_jwtvis_{gen_iterations}.png'), res)
                    load_params(netG, backup_para)

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch, text_encoder,
                                image_encoder)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch, text_encoder,
                        image_encoder)
    def train(self):
        """Train G/D with a captioning model (CNN + RNN) in the generator loss.

        ``caption_cnn`` and ``caption_rnn`` are passed to ``generator_loss``
        together with the ground-truth captions. Presumably this adds a
        caption-reconstruction term — confirm in generator_loss.

        Each batch:
        - compute detached text embeddings;
        - generate fake images and update each discriminator once;
        - update the generator (plus KL term) and the EMA copy of its
          weights in ``avg_param_G``.

        Previews use the averaged weights every 1000 generator iterations.
        Checkpoints are saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs and
        once after the final epoch.
        """
        text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, start_epoch = \
            self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)

        # Presumably: real_labels are ones and fake_labels zeros (one per
        # sample), match_labels are 0..batch_size-1 — confirm in prepare_labels.
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM  # dimension of the random noise vector
        noise = Variable(torch.FloatTensor(batch_size, nz))  # batch_size x nz
        # Fixed noise keeps preview images comparable across snapshots.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # (1) Prepare training data and compute text embeddings.
                # Builtin next(): recent PyTorch DataLoader iterators have no .next().
                data = next(data_iter)
                # prepare_data presumably moves the batch to the GPU and
                # sorts it by caption length.
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                # (nef presumably = num_hidden * num_directions of the RNN)
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Detached: no gradient flows back into the text encoder.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # True where the token id is 0, truncated to the number of words.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # (2) Generate fake images. mu/logvar feed the KL term below.
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                # (3) Update D networks (netsD[i] judges fake_imgs[i]).
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.data)

                # (4) Update G network: maximize log(D(G(z))).
                step += 1
                gen_iterations += 1
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data
                errG_total.backward()
                optimizerG.step()
                # EMA of the generator weights (decay 0.999); alpha= form
                # replaces the deprecated add_(scalar, tensor) overload.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # Save previews with the averaged weights, then restore.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.data,
                   errG_total.data, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
    def train(self):
        """Train the multi-stage generators, discriminators and attribute classifiers.

        Each iteration:
          (1) draws a batch and unpacks images / attributes with ``prepare_data``;
          (2) runs the generator stages ``netsG[0..cfg.TREE.BRANCH_NUM-1]``,
              feeding the attribute embeddings returned by ``classifier_loss``
              into the later stages;
          (3) updates every discriminator in ``netsD``;
          (4) back-propagates the classifier losses, then the generator loss,
              and steps the generator / classifier optimizers.

        A moving average of every generator's parameters is maintained; it is
        swapped in for sample images and passed to ``save_model``. Per-step
        losses are written to a ``SummaryWriter`` at ``self.args.run_dir``.
        Models are saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs and once
        after the final epoch.
        """
        netsG, netsD, inception_model, classifiers, start_epoch = self.build_models(
        )
        # Moving-average copies of each generator's parameters.
        avg_params_G = [copy_G_params(netG) for netG in netsG]
        optimizersG, optimizersD, optimizersC = self.define_optimizers(
            netsG, netsD, classifiers)
        real_labels, fake_labels = self.prepare_labels()
        writer = SummaryWriter(self.args.run_dir)

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            print("epoch: {}/{}".format(epoch, self.max_epoch))
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                print("step:{}/{} {:.2f}%".format(
                    step, self.num_batches, step / self.num_batches * 100))

                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # Python 3 iterators (DataLoader iterators included) implement
                # __next__, not .next(); use the builtin.
                data = next(data_iter)
                real_imgs, atts, image_atts, class_ids, keys = prepare_data(
                    data)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)

                print("before netG")
                fake_imgs = []
                C_losses = None
                # NOTE(review): when self.args.kl_loss is set, no fake images
                # are produced and mu/logvar are never defined, so only the
                # branch below is actually implemented.
                if not self.args.kl_loss:
                    if cfg.TREE.BRANCH_NUM > 0:
                        fake_img1, h_code1 = nn.parallel.data_parallel(
                            netsG[0], (noise, atts, image_atts), self.gpus)
                        fake_imgs.append(fake_img1)
                        if self.args.split == 'train':
                            att_embeddings, C_losses = classifier_loss(
                                classifiers, inception_model, real_imgs[0],
                                image_atts, C_losses)
                            _, C_losses = classifier_loss(
                                classifiers, inception_model, fake_img1,
                                image_atts, C_losses)
                        else:
                            att_embeddings, _ = classifier_loss(
                                classifiers, inception_model, fake_img1,
                                image_atts)

                    if cfg.TREE.BRANCH_NUM > 1:
                        fake_img2, h_code2 = nn.parallel.data_parallel(
                            netsG[1], (h_code1, att_embeddings), self.gpus)
                        fake_imgs.append(fake_img2)
                        # Stage 2 classifies its own output (fake_img2), which
                        # is paired with real_imgs[1] in the D update below,
                        # rather than re-using the stage-1 image.
                        if self.args.split == 'train':
                            att_embeddings, C_losses = classifier_loss(
                                classifiers, inception_model, real_imgs[1],
                                image_atts, C_losses)
                            _, C_losses = classifier_loss(
                                classifiers, inception_model, fake_img2,
                                image_atts, C_losses)
                        else:
                            att_embeddings, _ = classifier_loss(
                                classifiers, inception_model, fake_img2,
                                image_atts)

                    if cfg.TREE.BRANCH_NUM > 2:
                        fake_img3 = nn.parallel.data_parallel(
                            netsG[2], (h_code2, att_embeddings), self.gpus)
                        fake_imgs.append(fake_img3)
                print("end netG")

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                errD_dic = {}
                for i, netD in enumerate(netsD):
                    netD.zero_grad()
                    errD = discriminator_loss(netD, real_imgs[i],
                                              fake_imgs[i], atts, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    errD_dic['D_%d' % i] = errD.item()

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                for netG in netsG:
                    netG.zero_grad()
                print("before c backward")
                # retain_graph=True: the classifier losses share graph with
                # the fake images, which errG_total back-propagates later.
                errC_total = 0
                C_logs = ''
                for i, classifier in enumerate(classifiers):
                    classifier.zero_grad()
                    C_losses[i].backward(retain_graph=True)
                    optimizersC[i].step()
                    errC_total += C_losses[i]
                C_logs += 'errC_total: %.2f ' % (errC_total.item())
                print("end c backward")

                # TODO: verify that gradients actually accumulate in netsG.

                errG_total = 0
                errG_total, G_logs, errG_dic = \
                    generator_loss(netsD, fake_imgs, real_labels, atts, errG_total)
                if self.args.kl_loss:
                    kl_loss = KL_loss(mu, logvar)
                    errG_total += kl_loss
                    G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                    writer.add_scalar('kl_loss', kl_loss.item(),
                                      epoch * self.num_batches + step)

                # backward and update parameters
                errG_total.backward()
                for optimizerG in optimizersG:
                    optimizerG.step()
                # NOTE(review): classifier optimizers are stepped a second time
                # here, possibly with gradients accumulated from errG_total;
                # confirm this is intended.
                for optimizerC in optimizersC:
                    optimizerC.step()

                errD_dic.update(errG_dic)
                writer.add_scalars('training_losses', errD_dic,
                                   epoch * self.num_batches + step)

                # Exponential moving average of generator weights
                # (add_ with alpha= replaces the deprecated add_(Number, Tensor)).
                for netG, avg_param_G in zip(netsG, avg_params_G):
                    for p, avg_p in zip(netG.parameters(), avg_param_G):
                        avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs + '\n' + C_logs)
                # Save sample images with the averaged weights, then restore.
                if gen_iterations % 1000 == 0:
                    backup_paras = []
                    for netG, avg_param_G in zip(netsG, avg_params_G):
                        backup_paras.append(copy_G_params(netG))
                        load_params(netG, avg_param_G)
                    self.save_img_results(netsG,
                                          fixed_noise,
                                          atts,
                                          image_atts,
                                          inception_model,
                                          classifiers,
                                          real_imgs,
                                          epoch,
                                          name='average')
                    for netG, backup_para in zip(netsG, backup_paras):
                        load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Loss_C: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), errC_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netsG, avg_params_G, netsD, classifiers, epoch)

        self.save_model(netsG, avg_params_G, netsD, classifiers,
                        self.max_epoch)
Esempio n. 22
0
    def train(self):
        """Train the generator ``netG`` against the per-stage discriminators ``netsD``.

        Word/sentence embeddings come from ``text_encoder`` and are detached,
        so no gradient flows back into it. The generator loss also uses
        ``image_encoder`` and the caption CNN/RNN, plus a KL term on the
        ``mu``/``logvar`` returned by ``netG``. A moving average of the
        generator weights is swapped in for sample images and passed to
        ``save_model``. Models are saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL``
        epochs and once after the final epoch.
        """
        text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # (1) Prepare training data and compute text embeddings.
                # Builtin next(): Python 3 iterators have no .next() method.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask token ids 0, 1 and 2 (<start>, <end>, <pad>). A logical
                # OR states the intent directly instead of summing booleans.
                mask = (captions == 0) | (captions == 1) | (captions == 2)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # (2) Generate fake images
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                # (3) Update D network
                errD_total = 0
                D_logs = ''
                for i, netD in enumerate(netsD):
                    netD.zero_grad()
                    errD = discriminator_loss(netD, imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                # (4) Update G network: maximize log(D(G(z)))
                step += 1
                gen_iterations += 1
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # Exponential moving average of generator weights
                # (add_ with alpha= replaces the deprecated add_(Number, Tensor)).
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # Save sample images with the averaged weights, then restore.
                if gen_iterations % 1000 == 0:
                    print('Saving images...')
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Esempio n. 23
0
    def train(self, model):
        """Train ``netG``/``netsD`` with optional CLIP sentence embeddings and CLIP loss.

        Args:
            model: presumably a CLIP model (see ``clip.tokenize`` below). Its
                ``encode_text`` is used when ``cfg.TRAIN.CLIP_SENTENCODER`` or
                ``cfg.TRAIN.CLIP_LOSS`` is set, and it is always passed on to
                ``generator_loss``.

        DAMSM word/sentence embeddings come from ``text_encoder`` and are
        detached. With ``CLIP_SENTENCODER`` the networks are conditioned on
        the CLIP sentence embedding instead of the DAMSM one. Models are saved
        when ``epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0`` or ``epoch % 10 == 0``,
        and once after the final epoch.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models(
        )  # load encoders
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0

        if cfg.TRAIN.CLIP_SENTENCODER:
            print("CLIP Sentence Encoder: True")

        if cfg.TRAIN.CLIP_LOSS:
            print("CLIP Loss: True")

        if cfg.TRAIN.EXTRA_LOSS:
            print("Extra DAMSM Loss in G: True")
            print("DAMSM Weight: ", cfg.TRAIN.WEIGHT_DAMSM_LOSS)

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # Builtin next(): Python 3 iterators have no .next() method.
                data = next(data_iter)

                # prepare_data also returns the raw caption texts.
                imgs, captions, cap_lens, class_ids, keys, texts = prepare_data(
                    data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs_damsm, sent_emb_damsm = text_encoder(
                    captions, cap_lens, hidden)
                words_embs_damsm, sent_emb_damsm = words_embs_damsm.detach(
                ), sent_emb_damsm.detach()
                mask = (captions == 0)
                num_words = words_embs_damsm.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                if cfg.TRAIN.CLIP_SENTENCODER or cfg.TRAIN.CLIP_LOSS:
                    # Pick one random non-blank line of each example's text.
                    # Blank lines (e.g. from a trailing '\n') are dropped, and
                    # every remaining line is eligible. np.random.randint's
                    # upper bound is exclusive, so passing len(lines) (not
                    # len - 1) keeps the last line selectable.
                    sents = []
                    for text in texts:
                        lines = [s for s in text.split('\n') if s.strip()]
                        if len(lines) > 1:
                            sents.append(lines[np.random.randint(0, len(lines))])
                        else:
                            sents.append(lines[0] if lines else '')

                    sent = clip.tokenize(sents)
                    sent_input = sent.cuda()

                    with torch.no_grad():
                        sent_emb_clip = model.encode_text(sent_input).float()
                        if cfg.TRAIN.CLIP_SENTENCODER:
                            sent_emb = sent_emb_clip
                        else:
                            sent_emb = sent_emb_damsm
                else:
                    # Placeholder passed to generator_loss when CLIP is unused.
                    sent_emb_clip = 0
                    sent_emb = sent_emb_damsm

                words_embs = words_embs_damsm

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i, netD in enumerate(netsD):
                    netD.zero_grad()
                    errD = discriminator_loss(netD, imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()

                # The CLIP model and both sentence embeddings are passed so
                # generator_loss can add the CLIP loss when enabled.
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids,
                                   model, sent_emb_damsm, sent_emb_clip)

                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # Exponential moving average of generator weights
                # (add_ with alpha= replaces the deprecated add_(Number, Tensor)).
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # Save sample images with the averaged weights, then restore.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or epoch % 10 == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Esempio n. 24
0
    def train(self):
        """Train the image-manipulation generator ``netG`` against a single discriminator ``netD``.

        ``netG`` receives the real image ``imgs[-1]``, the caption
        embeddings, noise, a word mask and ``VGG`` features of that image.
        It returns a modified image plus ``mu``/``logvar`` for the KL term.
        ``netD`` is scored with both matched (``sent_emb``) and mismatched
        (``w_sent_emb``) sentence embeddings. A moving average of the
        generator weights is swapped in for sample images and passed to
        ``save_model``. Models are saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL``
        epochs and once after the final epoch.
        """
        text_encoder, netG, netD, start_epoch, VGG = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizerD = self.define_optimizers(netG, netD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # Builtin next(): Python 3 iterators have no .next() method.
                data = next(data_iter)
                imgs, w_imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id, sorted_cap_indices, \
                    w_sorted_cap_indices = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef

                # Matched text embeddings.
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # Mismatched ("wrong") text embeddings.
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(
                ), w_sent_emb.detach()
                # Re-order every wrong-caption tensor via reverse_indices so it
                # lines up with the matched batch. This presumably undoes the
                # separate length-sorts done in prepare_data; confirm.
                w_words_embs = self.reverse_indices(w_words_embs,
                                                    sorted_cap_indices,
                                                    w_sorted_cap_indices)
                w_sent_emb = self.reverse_indices(w_sent_emb,
                                                  sorted_cap_indices,
                                                  w_sorted_cap_indices)
                wrong_caps = self.reverse_indices(wrong_caps,
                                                  sorted_cap_indices,
                                                  w_sorted_cap_indices)
                wrong_caps_len = self.reverse_indices(wrong_caps_len,
                                                      sorted_cap_indices,
                                                      w_sorted_cap_indices)
                wrong_cls_id = self.reverse_indices(wrong_cls_id,
                                                    sorted_cap_indices,
                                                    w_sorted_cap_indices)

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                enc_features = VGG(imgs[-1])
                fake_img, mu, logvar = nn.parallel.data_parallel(
                    netG, (imgs[-1], sent_emb, words_embs, noise, mask,
                           enc_features), self.gpus)

                #######################################################
                # (3) Update D network
                ######################################################
                netD.zero_grad()
                errD, D_logs = discriminator_loss(netD, imgs[-1], fake_img,
                                                  sent_emb, w_sent_emb,
                                                  real_labels, fake_labels)
                errD.backward()
                optimizerD.step()

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netD, fake_img, imgs[-1], w_imgs[-1], real_labels, sent_emb, VGG, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()

                # Exponential moving average of generator weights
                # (add_ with alpha= replaces the deprecated add_(Number, Tensor)).
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # Save sample images with the averaged weights, then restore.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, w_sent_emb, w_words_embs,
                                          captions, wrong_caps, epoch, imgs,
                                          mask, VGG)
                    load_params(netG, backup_para)

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD.item(),
                   errG_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netD, epoch)

        self.save_model(netG, avg_param_G, netD, self.max_epoch)
Esempio n. 25
0
    def train(self):
        """Run the attribute-conditioned GAN training loop.

        Per batch: caption embeddings are computed with ``text_encoder`` and
        detached, ``netG`` generates fake images, every discriminator in
        ``netsD`` is updated on ``errD + cls_D / 3`` (real images first get
        instance noise via ``gaussian_to_input``), then the generator is
        updated on ``generator_loss`` + KL loss and the moving average of its
        weights (``avg_param_G``, decay 0.999) is refreshed.

        Logs every 100 generator iterations, saves sample images rendered with
        the averaged weights every 1000, and checkpoints every
        ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs plus once after the last epoch.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                batch_t_begin = time.time()

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Use the next() builtin: DataLoader iterators in recent
                # PyTorch releases no longer provide a .next() method.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, color_ids, sleeve_ids, gender_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Padding positions (token id 0), trimmed to the embedding length.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                D_logs_cls = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    # Instance noise; the noisy images are also what the
                    # generator loss below receives.
                    imgs[i] = gaussian_to_input(imgs[i])
                    errD, cls_D = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                     sent_emb, real_labels, fake_labels,
                                                     class_ids, color_ids, sleeve_ids, gender_ids)
                    # The cls_D term is weighted by 1/3 relative to errD.
                    errD_both = errD + cls_D / 3.
                    errD_both.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    errD_total += cls_D / 3.0
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    D_logs_cls += 'clsD%d: %.2f ' % (i, cls_D.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids,
                                   color_ids, sleeve_ids, gender_ids, imgs)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                errG_total.backward()
                optimizerG.step()
                # Exponential moving average of generator weights (decay 0.999),
                # using the non-deprecated add_(tensor, alpha=...) signature.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    batch_t_end = time.time()
                    print('| epoch {:3d} | {:5d}/{:5d} batches | batch_timer: {:5.2f} | '
                          .format(epoch, step, self.num_batches,
                                  batch_t_end - batch_t_begin,))
                    print(D_logs + '\n' + D_logs_cls + '\n' + G_logs)
                # Save sample images rendered with the averaged weights, then
                # restore the live training weights.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Esempio n. 26
0
    def train(self):
        """Train the GAN and plot the recorded per-epoch losses at the end.

        Per batch: caption embeddings are computed with ``text_encoder`` and
        detached, ``netG`` generates fake images, each discriminator in
        ``netsD`` takes one optimizer step, then the generator is updated on
        ``generator_loss`` + KL loss and the moving average of its weights
        (``avg_param_G``, decay 0.999) is refreshed.

        Logs every 100 generator iterations, saves sample images rendered with
        the averaged weights every 1000, and checkpoints every
        ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs. After training, the last
        batch's D, G, KL, sentence and word losses of every epoch are shown in
        three matplotlib figures, and a final checkpoint is written.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches

        # Per-epoch loss histories (plain Python floats, see _scalar).
        errorD = []
        errorG = []
        loss_KL = []
        loss_s = []
        loss_w = []

        def _scalar(value):
            # matplotlib cannot convert CUDA tensors or tensors that require
            # grad to numpy, so store detached Python numbers. This also stops
            # the lists from holding on to autograd history.
            return value.item() if torch.is_tensor(value) else float(value)

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Use the next() builtin: DataLoader iterators in recent
                # PyTorch releases no longer provide a .next() method.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Padding positions (token id 0), trimmed to the embedding length.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs, w_loss, s_loss = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                errG_total.backward()
                optimizerG.step()
                # Exponential moving average of generator weights (decay 0.999),
                # using the non-deprecated add_(tensor, alpha=...) signature.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # Save sample images rendered with the averaged weights, then
                # restore the live training weights.
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), end_t - start_t))

            # Record the last batch's losses for this epoch.
            errorD.append(_scalar(errD_total))
            errorG.append(_scalar(errG_total))
            loss_KL.append(_scalar(kl_loss))
            loss_s.append(_scalar(s_loss))
            loss_w.append(_scalar(w_loss))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        plt.plot(errorG, label="Generator Loss")
        plt.plot(errorD, label="Discriminator Loss")
        plt.legend()
        plt.title("loss function for each epoch")
        plt.show()

        plt.plot(loss_KL, label="KL Loss")
        plt.title("KL loss function")
        plt.show()

        plt.plot(loss_s, label="sent Loss")
        plt.plot(loss_w, label="word Loss")
        plt.legend()
        plt.title("specfic loss function in generator")
        plt.show()

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Esempio n. 27
0
    def train(self):
        """Train the layout-conditioned generator and its discriminators.

        ``self.batch_size`` is a sequence of batch sizes. With
        ``cfg.TRAIN.OPTIMIZE_DATA_LOADING`` there is one data loader (and one
        batch size / noise buffer) per subset. Each step samples a subset with
        probability proportional to its remaining batches, and that subset
        index is passed on as ``max_objects``. Otherwise a single loader is
        used with ``batch_size[0]`` and ``max_objects = 3``.

        Each step updates every discriminator in ``netsD``, then the generator
        (``generator_loss`` + KL loss), then the moving average of generator
        weights (``avg_param_G``, decay 0.999). Sample images rendered with
        the averaged weights are saved mid-epoch and near the end of the
        epoch. A checkpoint (with optimizer states) is written after every
        epoch.

        Note: autograd anomaly detection is enabled globally here, which adds
        significant runtime overhead.
        """
        torch.autograd.set_detect_anomaly(True)

        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
            # One noise buffer per loader, sized to that loader's batch size.
            batch_sizes = self.batch_size
            noise, local_noise, fixed_noise = [], [], []
            for batch_size in batch_sizes:
                noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM)).to(cfg.DEVICE))
                local_noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE))
                fixed_noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM).normal_(0, 1)).to(cfg.DEVICE))
        else:
            batch_size = self.batch_size[0]
            noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM)).to(cfg.DEVICE)
            local_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE)
            fixed_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM).normal_(0, 1)).to(cfg.DEVICE)

        for epoch in range(start_epoch, self.max_epoch):
            logger.info("Epoch nb: %s", epoch)
            gen_iterations = 0
            if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                data_iter = [iter(loader) for loader in self.data_loader]
                # Sample loaders in proportion to how many batches each has left.
                current_probability = [len(loader) for loader in self.data_loader]
                total_batches_left = sum(current_probability)
                current_probability_percent = [count / float(total_batches_left)
                                               for count in current_probability]
            else:
                data_iter = iter(self.data_loader)

            _dataset = tqdm(range(self.num_batches))
            for step in _dataset:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Use the next() builtin: DataLoader iterators in recent
                # PyTorch releases no longer provide a .next() method.
                if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                    subset_idx = np.random.choice(range(len(self.data_loader)), size=None,
                                                  p=current_probability_percent)
                    total_batches_left -= 1
                    if total_batches_left > 0:
                        current_probability[subset_idx] -= 1
                        current_probability_percent = [count / float(total_batches_left)
                                                       for count in current_probability]

                    max_objects = subset_idx
                    data = next(data_iter[subset_idx])
                else:
                    data = next(data_iter)
                    max_objects = 3
                _dataset.set_description('Obj-{}'.format(max_objects))

                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot = prepare_data(data)
                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                with torch.no_grad():
                    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                        hidden = text_encoder.init_hidden(batch_sizes[subset_idx])
                    else:
                        hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    # Padding positions (token id 0), trimmed to the embedding length.
                    mask = (captions == 0).bool()
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                    noise[subset_idx].data.normal_(0, 1)
                    local_noise[subset_idx].data.normal_(0, 1)
                    inputs = (noise[subset_idx], local_noise[subset_idx], sent_emb, words_embs, mask, transf_matrices,
                              transf_matrices_inv, label_one_hot, max_objects)
                else:
                    noise.data.normal_(0, 1)
                    local_noise.data.normal_(0, 1)
                    inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices, transf_matrices_inv,
                              label_one_hot, max_objects)

                inputs = tuple((inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor) else inp) for inp in inputs)
                fake_imgs, _, mu, logvar = netG(*inputs)

                #######################################################
                # (3) Update D network
                ######################################################
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels[subset_idx], fake_labels[subset_idx],
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv, cfg=cfg,
                                                  max_objects=max_objects)
                    else:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels,
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv, cfg=cfg,
                                                  max_objects=max_objects)

                    errD.backward()
                    optimizersD[i].step()
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                gen_iterations += 1

                netG.zero_grad()
                if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                    errG_total = \
                        generator_loss(netsD, image_encoder, fake_imgs, real_labels[subset_idx],
                                       words_embs, sent_emb, match_labels[subset_idx], cap_lens, class_ids,
                                       local_labels=label_one_hot, transf_matrices=transf_matrices,
                                       transf_matrices_inv=transf_matrices_inv, max_objects=max_objects)
                else:
                    errG_total = \
                        generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                       words_embs, sent_emb, match_labels, cap_lens, class_ids,
                                       local_labels=label_one_hot, transf_matrices=transf_matrices,
                                       transf_matrices_inv=transf_matrices_inv, max_objects=max_objects)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                errG_total.backward()
                optimizerG.step()
                # Exponential moving average of generator weights (decay 0.999).
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if cfg.TRAIN.EMPTY_CACHE:
                    torch.cuda.empty_cache()

                # Save sample images around the middle of the epoch and near
                # its end.
                # NOTE(review): gen_iterations == step + 1 here, so the last
                # condition fires after the second-to-last batch rather than
                # the last one; confirm this is intended.
                if (
                        2 * gen_iterations == self.num_batches
                        or 2 * gen_iterations + 1 == self.num_batches
                        or gen_iterations + 1 == self.num_batches
                ):
                    logger.info('Saving images...')
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                        self.save_img_results(netG, fixed_noise[subset_idx], sent_emb,
                                              words_embs, mask, image_encoder,
                                              captions, cap_lens, epoch, transf_matrices_inv,
                                              label_one_hot, local_noise[subset_idx], transf_matrices,
                                              max_objects, subset_idx, name='average')
                    else:
                        self.save_img_results(netG, fixed_noise, sent_emb,
                                              words_embs, mask, image_encoder,
                                              captions, cap_lens, epoch, transf_matrices_inv,
                                              label_one_hot, local_noise, transf_matrices,
                                              max_objects, None, name='average')
                    load_params(netG, backup_para)

            self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)
        # The loop above already checkpoints the final epoch. The former extra
        # save_model(..., epoch) here only rewrote that same file, and raised
        # UnboundLocalError when start_epoch >= max_epoch.
Esempio n. 28
0
    def train(self):

        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

        tb_dir = '../tensorboard/{0}_{1}_{2}'.format(cfg.DATASET_NAME,
                                                     cfg.CONFIG_NAME,
                                                     timestamp)
        mkdir_p(tb_dir)
        tbw = SummaryWriter(log_dir=tb_dir)  # Tensorboard logging

        text_encoder, image_encoder, netG, netsD, start_epoch, cap_model = self.build_models(
        )
        labels = Variable(torch.LongTensor(range(
            self.batch_size)))  # used for matching loss

        text_encoder.train()
        image_encoder.train()
        for k, v in image_encoder.named_children(
        ):  # set the input layer1-5 not training and no grads.
            if k in frozen_list_image_encoder:
                v.training = False
                v.requires_grad_(False)
        netG.train()
        for i in range(len(netsD)):
            netsD[i].train()
        cap_model.train()

        avg_param_G = copy_G_params(netG)
        optimizerI, optimizerT, optimizerG , optimizersD , optimizerC , lr_schedulerC \
        , lr_schedulerI , lr_schedulerT = self.define_optimizers(image_encoder
                                                                , text_encoder
                                                                , netG
                                                                , netsD
                                                                , cap_model)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))

        cap_criterion = torch.nn.CrossEntropyLoss(
        )  # add caption criterion here
        if cfg.CUDA:
            labels = labels.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            cap_criterion = cap_criterion.cuda()  # add caption criterion here
        cap_criterion.train()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):

            ##### set everything to trainable ####
            text_encoder.train()
            image_encoder.train()
            netG.train()
            cap_model.train()
            for k, v in image_encoder.named_children():
                if k in frozen_list_image_encoder:
                    v.train(False)
            for i in range(len(netsD)):
                netsD[i].train()
            ##### set everything to trainable ####

            fi_w_total_loss0 = 0
            fi_w_total_loss1 = 0
            fi_s_total_loss0 = 0
            fi_s_total_loss1 = 0
            ft_w_total_loss0 = 0
            ft_w_total_loss1 = 0
            ft_s_total_loss0 = 0
            ft_s_total_loss1 = 0
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            c_total_loss = 0

            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                print('step:{:6d}|{:3d}'.format(step, self.num_batches),
                      end='\r')
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = data_iter.next()
                # add images, image masks, captions, caption masks for catr model
                imgs, captions, cap_lens, class_ids, keys, cap_imgs, cap_img_masks, sentences, sent_masks = prepare_data(
                    data)

                ################## feedforward damsm model ##################
                image_encoder.zero_grad()  # image/text encoders zero_grad here
                text_encoder.zero_grad()

                words_features, sent_code = image_encoder(
                    cap_imgs
                )  # input catr images to image encoder, feedforward, Nx256x17x17
                #                 words_features, sent_code = image_encoder(imgs[-1]) # input image_encoder
                nef, att_sze = words_features.size(1), words_features.size(2)
                # hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(
                    captions)  #, cap_lens, hidden)

                #### damsm losses
                w_loss0, w_loss1, attn_maps = words_loss(
                    words_features, words_embs, labels, cap_lens, class_ids,
                    batch_size)
                w_total_loss0 += w_loss0.data
                w_total_loss1 += w_loss1.data
                damsm_loss = w_loss0 + w_loss1

                s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                             class_ids, batch_size)
                s_total_loss0 += s_loss0.data
                s_total_loss1 += s_loss1.data
                damsm_loss += s_loss0 + s_loss1

                #                 damsm_loss.backward()

                #                 words_features = words_features.detach()
                # real image real text matching loss graph cleared here
                # grad accumulated -> text_encoder
                #                  -> image_encoder
                #################################################################################

                ################## feedforward image encoder and caption model ##################
                #                 words_features, sent_code = image_encoder(cap_imgs)
                cap_model.zero_grad()  # caption model zero_grad here

                cap_preds = cap_model(
                    words_features, cap_img_masks, sentences[:, :-1],
                    sent_masks[:, :-1])  # caption model feedforward
                cap_loss = caption_loss(cap_criterion, cap_preds, sentences)
                c_total_loss += cap_loss.data
                #                 cap_loss.backward() # caption loss graph cleared,
                # grad accumulated -> cap_model -> image_encoder
                torch.nn.utils.clip_grad_norm_(cap_model.parameters(),
                                               config.clip_max_norm)
                #                 optimizerC.step() # update cap_model params
                #################################################################################

                ############ Prepare the input to Gan from the output of text_encoder ################
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #                 f_img = np.asarray(fake_imgs[-1].permute((0,2,3,1)).detach().cpu())
                #                 print('fake_imgs.size():{0},fake_imgs.min():{1},fake_imgs.max():{2}'.format(fake_imgs[-1].size()
                #                                   ,fake_imgs[-1].min()
                #                                   ,fake_imgs[-1].max()))

                #                 print('f_img.shape:{0}'.format(f_img.shape))

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    print(i)
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.data)

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # 14 -- 2800 iterations=steps for 1 epoch
                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')

                #### temporary check ####
#                 if step == 5:
#                     break
#                 print('fake_img shape:',fake_imgs[-1].size())

#                 # this is fine #### exists in GeneratorLoss
#                 fake_imgs[-1] = fake_imgs[-1].detach()
#                 ####### fake imge real text matching loss #################
#                 fi_word_features, fi_sent_code = image_encoder(fake_imgs[-1])
# #                 words_embs, sent_emb = text_encoder(captions) # to update the text

#                 fi_w_loss0, fi_w_loss1, fi_attn_maps = words_loss(fi_word_features, words_embs, labels,
#                                                  cap_lens, class_ids, batch_size)

#                 fi_w_total_loss0 += fi_w_loss0.data
#                 fi_w_total_loss1 += fi_w_loss1.data

#                 fi_damsm_loss = fi_w_loss0 + fi_w_loss1

#                 fi_s_loss0, fi_s_loss1 = sent_loss(fi_sent_code, sent_emb, labels, class_ids, batch_size)

#                 fi_s_total_loss0 += fi_s_loss0.data
#                 fi_s_total_loss1 += fi_s_loss1.data

#                 fi_damsm_loss += fi_s_loss0 + fi_s_loss1

#                 fi_damsm_loss.backward()

###### real image fake text matching loss ##############

                fake_preds = torch.argmax(cap_preds,
                                          axis=-1)  # capation predictions
                fake_captions = tokenizer.batch_decode(
                    fake_preds.tolist(),
                    skip_special_tokens=True)  # list of strings
                fake_outputs = retokenizer.batch_encode_plus(
                    fake_captions,
                    max_length=64,
                    padding='max_length',
                    add_special_tokens=False,
                    return_attention_mask=True,
                    return_token_type_ids=False,
                    truncation=True)
                fake_tokens = fake_outputs['input_ids']
                #                 fake_tkmask = fake_outputs['attention_mask']
                f_tokens = np.zeros((len(fake_tokens), 15), dtype=np.int64)
                f_cap_lens = []
                cnt = 0
                for i in fake_tokens:
                    temp = np.array([x for x in i if x != 27299 and x != 0])
                    num_words = len(temp)
                    if num_words <= 15:
                        f_tokens[cnt][:num_words] = temp
                    else:
                        ix = list(np.arange(num_words))  # 1, 2, 3,..., maxNum
                        np.random.shuffle(ix)
                        ix = ix[:15]
                        ix = np.sort(ix)
                        f_tokens[cnt] = temp[ix]
                        num_words = 15
                    f_cap_lens.append(num_words)
                    cnt += 1

                f_tokens = Variable(torch.tensor(f_tokens))
                f_cap_lens = Variable(torch.tensor(f_cap_lens))
                if cfg.CUDA:
                    f_tokens = f_tokens.cuda()
                    f_cap_lens = f_cap_lens.cuda()

                ft_words_emb, ft_sent_emb = text_encoder(
                    f_tokens)  # input text_encoder

                ft_w_loss0, ft_w_loss1, ft_attn_maps = words_loss(
                    words_features, ft_words_emb, labels, f_cap_lens,
                    class_ids, batch_size)

                ft_w_total_loss0 += ft_w_loss0.data
                ft_w_total_loss1 += ft_w_loss1.data

                ft_damsm_loss = ft_w_loss0 + ft_w_loss1

                ft_s_loss0, ft_s_loss1 = sent_loss(sent_code, ft_sent_emb,
                                                   labels, class_ids,
                                                   batch_size)

                ft_s_total_loss0 += ft_s_loss0.data
                ft_s_total_loss1 += ft_s_loss1.data

                ft_damsm_loss += ft_s_loss0 + ft_s_loss1

                #                 ft_damsm_loss.backward()

                total_multimodal_loss = damsm_loss + ft_damsm_loss + cap_loss
                total_multimodal_loss.backward()
                ## loss = 0.5*loss1 + 0.4*loss2 + ...
                ## loss.backward() -> accumulate grad value in parameters.grad

                ## loss1 = 0.5*loss1
                ## loss1.backward()

                torch.nn.utils.clip_grad_norm_(image_encoder.parameters(),
                                               cfg.TRAIN.RNN_GRAD_CLIP)

                optimizerI.step()

                torch.nn.utils.clip_grad_norm_(text_encoder.parameters(),
                                               cfg.TRAIN.RNN_GRAD_CLIP)
                optimizerT.step()

                optimizerC.step()  # update cap_model params

            lr_schedulerC.step()
            lr_schedulerI.step()
            lr_schedulerT.step()

            end_t = time.time()

            tbw.add_scalar('Loss_D', float(errD_total.item()), epoch)
            tbw.add_scalar('Loss_G', float(errG_total.item()), epoch)
            tbw.add_scalar('train_w_loss0', float(w_total_loss0.item()), epoch)
            tbw.add_scalar('train_s_loss0', float(s_total_loss0.item()), epoch)
            tbw.add_scalar('train_w_loss1', float(w_total_loss1.item()), epoch)
            tbw.add_scalar('train_s_loss1', float(s_total_loss1.item()), epoch)
            tbw.add_scalar('train_c_loss', float(c_total_loss.item()), epoch)

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.data,
                   errG_total.data, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:

                self.save_model(netG, avg_param_G, image_encoder, text_encoder,
                                netsD, epoch, cap_model, optimizerC,
                                optimizerI, optimizerT, lr_schedulerC,
                                lr_schedulerI, lr_schedulerT)

            v_s_cur_loss, v_w_cur_loss, v_c_cur_loss = self.evaluate(
                self.dataloader_val, image_encoder, text_encoder, cap_model,
                self.batch_size)
            print(
                'v_s_cur_loss:{:.5f}, v_w_cur_loss:{:.5f}, v_c_cur_loss:{:.5f}'
                .format(v_s_cur_loss, v_w_cur_loss, v_c_cur_loss))
            tbw.add_scalar('val_w_loss', float(v_w_cur_loss), epoch)
            tbw.add_scalar('val_s_loss', float(v_s_cur_loss), epoch)
            tbw.add_scalar('val_c_loss', float(v_c_cur_loss), epoch)

        self.save_model(netG, avg_param_G, image_encoder, text_encoder, netsD,
                        self.max_epoch, cap_model, optimizerC, optimizerI,
                        optimizerT, lr_schedulerC, lr_schedulerI,
                        lr_schedulerT)
Esempio n. 29
0
    def train(self):
        """Run the GAN training loop for ``netG`` and the discriminators ``netsD``.

        Per step: captions are encoded with ``text_encoder`` (embeddings are
        detached, so the text encoder is not updated here), fake images are
        generated with ``netG`` via ``nn.parallel.data_parallel`` over
        ``self.gpus``, every discriminator is updated with
        ``discriminator_loss``, and then the generator is updated with
        ``generator_loss`` plus ``KL_loss(mu, logvar)``. An exponential moving
        average of the generator parameters (decay 0.999) is kept in
        ``avg_param_G``. Those averaged weights are loaded into ``netG`` while
        sample images are saved, and ``avg_param_G`` is passed to
        ``save_model``.

        A checkpoint is saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs and
        once more after the last trained epoch.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches

        # `epoch` stays None when start_epoch >= self.max_epoch (nothing left
        # to train); the final save below checks this instead of hitting an
        # unbound name.
        epoch = None
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # Builtin next() works on every iterator; the `.next()` method
                # alias no longer exists on PyTorch DataLoader iterators.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot = prepare_data(data)
                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Padding positions (token id 0), trimmed to the embedded length.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot)
                fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)

                #######################################################
                # (3) Update D networks
                ######################################################
                errD_total = 0
                D_logs = ''
                for i, (netD, optimizerD) in enumerate(zip(netsD, optimizersD)):
                    netD.zero_grad()
                    if i == 0:
                        # Only the first discriminator also receives the local
                        # labels and the (inverse) transformation matrices.
                        errD = discriminator_loss(netD, imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels, self.gpus,
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv)
                    else:
                        errD = discriminator_loss(netD, imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels, self.gpus)

                    # backward and update parameters
                    errD.backward()
                    optimizerD.step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, self.gpus,
                                   local_labels=label_one_hot, transf_matrices=transf_matrices,
                                   transf_matrices_inv=transf_matrices_inv)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA of generator weights: avg_p = 0.999 * avg_p + 0.001 * p.
                # Keyword `alpha` replaces the deprecated add_(scalar, tensor).
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # print logs and save sample images with the averaged weights
                if gen_iterations % 1000 == 0:
                    print(D_logs + '\n' + G_logs)

                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch,  transf_matrices_inv,
                                          label_one_hot, name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)

        # Final checkpoint, only if at least one epoch was actually trained.
        if epoch is not None:
            self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)
Esempio n. 30
0
    def train(self):
        """Train ``netG`` against an instance and a global discriminator.

        ``netINSD`` is updated with ``ins_discriminator_loss`` and ``netGLBD``
        with ``glb_discriminator_loss``. The generator is then updated with
        ``generator_loss``, which also returns a per-batch perceptual score
        (computed with ``self.vgg_model``). Generator gradients are
        norm-clipped to ``cfg.TRAIN.RNN_GRAD_CLIP``. An exponential moving
        average of the generator parameters (decay 0.999) is kept in
        ``avg_param_G`` and loaded into ``netG`` while sample images are saved.

        After each epoch the mean perceptual score over ``self.num_batches``
        batches is printed and written to ``<score_dir>/scores_<epoch>.txt``.
        A checkpoint is saved every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs and
        once more at the end, tagged with ``self.max_epoch``.
        """
        netG, netINSD, netGLBD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)

        batch_size = self.batch_size
        # Noise shape: (batch, cfg.ROI.BOXES_NUM, 4 * number of categories).
        noise_shape = (batch_size, cfg.ROI.BOXES_NUM,
                       len(self.cats_index_dict) * 4)
        noise = Variable(torch.FloatTensor(*noise_shape))
        fixed_noise = Variable(torch.FloatTensor(*noise_shape).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        gen_iterations = 0
        lr_rate = 1
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            # After epoch 50, shrink the learning-rate factor by 2% per epoch.
            # NOTE(review): lr_rate starts at 1 and looks like a multiplier for
            # define_optimizers, yet it is compared against the absolute value
            # cfg.TRAIN.GENERATOR_LR / 10 -- confirm the intended floor.
            if epoch > 50 and lr_rate > cfg.TRAIN.GENERATOR_LR / 10.:
                lr_rate *= 0.98
            # Optimizers are requested anew every epoch with the current factor;
            # if define_optimizers builds fresh instances, their state resets
            # here -- TODO confirm that is intended.
            optimizerG, optimizerINSD, optimizerGLBD = self.define_optimizers(
                netG, netINSD, netGLBD, lr_rate)
            data_iter = iter(self.data_loader)
            step = 0
            # The perceptual score is a per-epoch average, so restart the sum
            # each epoch (previously the prior epoch's mean leaked into it).
            pcp_score = 0.

            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data
                ######################################################
                # Builtin next() works on every iterator; the `.next()` method
                # alias no longer exists on PyTorch DataLoader iterators.
                data = next(data_iter)
                imgs, pooled_hmaps, hmaps, bbox_maps_fwd, bbox_maps_bwd, bbox_fmaps, \
                    rois, fm_rois, num_rois, class_ids, keys = prepare_data(data)

                #######################################################
                # (2) Generate fake heatmaps
                ######################################################
                # Only feed noise for as many ROIs as the fullest sample has.
                max_num_roi = int(torch.max(num_rois))
                noise.data.normal_(0, 1)
                fake_hmaps = netG(noise[:, :max_num_roi], bbox_maps_fwd,
                                  bbox_maps_bwd, bbox_fmaps)

                #######################################################
                # (3-1) Update INSD network
                ######################################################
                netINSD.zero_grad()
                errINSD = ins_discriminator_loss(netINSD, hmaps, fake_hmaps,
                                                 bbox_maps_fwd)
                errINSD.backward()
                optimizerINSD.step()
                INSD_logs = 'errINSD: %.2f ' % (errINSD.item())

                #######################################################
                # (3-2) Update GLBD network
                ######################################################
                netGLBD.zero_grad()
                errGLBD = glb_discriminator_loss(netGLBD, pooled_hmaps,
                                                 fake_hmaps, bbox_maps_fwd)
                errGLBD.backward()
                optimizerGLBD.step()
                GLBD_logs = 'errGLBD: %.2f ' % (errGLBD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs, item_pcp_score = generator_loss(
                    netINSD, netGLBD, self.vgg_model, hmaps, fake_hmaps,
                    bbox_maps_fwd)
                pcp_score += item_pcp_score

                errG_total.backward()
                # Norm-clip generator gradients to guard against exploding
                # gradients.
                torch.nn.utils.clip_grad_norm_(netG.parameters(),
                                               cfg.TRAIN.RNN_GRAD_CLIP)
                optimizerG.step()
                # EMA of generator weights: avg_p = 0.999 * avg_p + 0.001 * p.
                # Keyword `alpha` replaces the deprecated add_(scalar, tensor).
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % self.print_interval == 0:
                    elapsed = time.time() - start_t
                    print(
                        '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                        .format(epoch, step, self.num_batches,
                                elapsed * 1000. / self.print_interval))
                    print(INSD_logs + '\n' + GLBD_logs + '\n' + G_logs)
                    start_t = time.time()

                # save images with the averaged generator weights
                if gen_iterations % self.display_interval == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise[:, :max_num_roi],
                                          imgs,
                                          bbox_maps_fwd,
                                          bbox_maps_bwd,
                                          bbox_fmaps,
                                          hmaps,
                                          rois,
                                          num_rois,
                                          gen_iterations,
                                          name='average')
                    load_params(netG, backup_para)

            pcp_score /= float(self.num_batches)
            print('pcp_score: ', pcp_score)
            fullpath = '%s/scores_%d.txt' % (self.score_dir, epoch)
            with open(fullpath, 'w') as fp:
                fp.write('pcp_score %f' % (pcp_score))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netINSD, netGLBD, epoch)

        self.save_model(netG, avg_param_G, netINSD, netGLBD, self.max_epoch)