示例#1
0
    def save_model(self, netDCM, avg_param_C, netD, epoch):
        """Write per-epoch checkpoints for ``netDCM`` and ``netD``.

        ``netDCM`` is saved with ``avg_param_C`` loaded into it; the
        parameters it held on entry are put back afterwards, even when
        writing the checkpoint fails. ``netD`` is saved as-is.

        Args:
            netDCM: module saved to ``<model_dir>/netC_epoch_<epoch>.pth``.
            avg_param_C: parameter values loaded into ``netDCM`` for the
                duration of the save (presumably the running average kept
                by the training loop -- confirm against the caller).
            netD: module saved to ``<model_dir>/netD_epoch_<epoch>.pth``.
            epoch: integer used in both file names.
        """
        backup_para = copy_G_params(netDCM)
        load_params(netDCM, avg_param_C)
        try:
            torch.save(netDCM.state_dict(),
                       '%s/netC_epoch_%d.pth' % (self.model_dir, epoch))
        finally:
            # Always hand the caller back the weights it passed in, so a
            # failed write cannot leave the swapped-in values in the module.
            load_params(netDCM, backup_para)

        torch.save(netD.state_dict(),
                   '%s/netD_epoch_%d.pth' % (self.model_dir, epoch))

        print('Save C/D models.')
示例#2
0
 def save_model(self, netG, avg_param_G, netsD, epoch):
     """Write a checkpoint for ``netG`` and for every discriminator.

     ``netG`` is saved with ``avg_param_G`` loaded into it; the parameters
     it held on entry are put back afterwards, even when writing the
     checkpoint fails.

     Args:
         netG: module saved to ``<model_dir>/netG_epoch_<epoch>.pth``.
         avg_param_G: parameter values loaded into ``netG`` for the
             duration of the save (presumably the running average kept
             by the training loop -- confirm against the caller).
         netsD: iterable of discriminator modules; the i-th one is saved
             to ``<model_dir>/netD<i>.pth``.
         epoch: integer used in the generator file name only.
     """
     backup_para = copy_G_params(netG)
     load_params(netG, avg_param_G)
     try:
         torch.save(netG.state_dict(),
                    '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
     finally:
         # Always hand the caller back the weights it passed in, so a
         # failed write cannot leave the swapped-in values in the module.
         load_params(netG, backup_para)

     # The discriminator file names carry no epoch, so each call
     # overwrites the files written by the previous call.
     for i, netD in enumerate(netsD):
         torch.save(netD.state_dict(),
                    '%s/netD%d.pth' % (self.model_dir, i))
     print('Save G/Ds models.')
示例#3
0
    def train(self):
        """Run the training loop for ``netDCM`` and a single ``netD``.

        Every iteration:
          1. draws a batch, computes detached text embeddings for the
             matched and the mismatched captions, and image features for
             the last-branch real image;
          2. runs ``netG`` and then ``netDCM`` to obtain ``fake_img``;
          3. takes one optimizer step on ``netD``;
          4. takes one optimizer step on ``netDCM`` and folds the new
             parameters into the running average ``avg_param_C``
             (``avg = 0.999 * avg + 0.001 * param``).

        No optimizer is created for ``netG`` in this method. Logs are
        printed every 100 iterations, sample images are written every 1000
        iterations using the averaged parameters, and checkpoints are
        written every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs and once more
        after the last epoch.
        """
        (text_encoder, image_encoder, netG, netD, start_epoch, VGG,
         netDCM) = self.build_models()
        avg_param_C = copy_G_params(netDCM)
        optimizerC, optimizerD = self.define_optimizers(netDCM, netD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        # `noise` is allocated uninitialised; it is refilled in place with
        # N(0, 1) samples before every netG call below.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Built-in next() instead of the Python 2 style `.next()`
                # method, which Python 3 iterators are not required to have.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef

                # matched text embeddings
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # mismatched text embeddings
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(
                ), w_sent_emb.detach()

                # image embeddings: regional and global
                region_features, cnn_code = image_encoder(
                    imgs[cfg.TREE.BRANCH_NUM - 1])

                # True where the caption token id is 0; trimmed so it has
                # no more columns than words_embs has positions.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar, h_code, c_code = netG(
                    noise, sent_emb, words_embs, mask, cnn_code,
                    region_features)

                real_img = imgs[cfg.TREE.BRANCH_NUM - 1]
                real_features = VGG(real_img)[0]
                fake_img = netDCM(h_code, real_features, sent_emb, words_embs,
                                  mask, c_code)

                #######################################################
                # (3) Update D network
                ######################################################
                netD.zero_grad()
                errD = discriminator_loss(netD, imgs[cfg.TREE.BRANCH_NUM - 1],
                                          fake_img, sent_emb, real_labels,
                                          fake_labels, words_embs, cap_lens,
                                          image_encoder, class_ids,
                                          w_words_embs, wrong_caps_len,
                                          wrong_cls_id)
                errD.backward()
                optimizerD.step()
                D_logs = 'errD: %.2f ' % (errD)

                #######################################################
                # (4) Update netDCM
                ######################################################
                # compute total loss for training netDCM
                step += 1
                gen_iterations += 1

                netDCM.zero_grad()
                errC_total, C_logs = \
                    DCM_generator_loss(netD, image_encoder, fake_img, real_labels,
                                       words_embs, sent_emb, match_labels, cap_lens,
                                       class_ids, VGG, real_img)

                errC_total.backward()
                optimizerC.step()
                # Running average of the netDCM parameters. The keyword
                # `alpha` form replaces the deprecated
                # add_(Number, Tensor) overload.
                for p, avg_p in zip(netDCM.parameters(), avg_param_C):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + C_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    # Render the samples with the averaged parameters, then
                    # restore the live training parameters.
                    backup_para = copy_G_params(netDCM)
                    load_params(netDCM, avg_param_C)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          cnn_code,
                                          region_features,
                                          imgs,
                                          netDCM,
                                          real_features,
                                          name='average')
                    load_params(netDCM, backup_para)

            end_t = time.time()

            # Reports the losses of the last batch of the epoch.
            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_C: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD, errC_total,
                   end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netDCM, avg_param_C, netD, epoch)

        self.save_model(netDCM, avg_param_C, netD, self.max_epoch)
示例#4
0
    def train(self):
        """Run the training loop for ``netG`` and the discriminators ``netsD``.

        Every iteration:
          1. draws a batch, computes detached text embeddings for the
             matched and the mismatched captions, and image features for
             the image at index ``len(netsD) - 1``;
          2. runs ``netG`` to obtain one fake image per discriminator;
          3. takes one optimizer step on each discriminator;
          4. takes one optimizer step on ``netG`` (generator loss plus
             ``KL_loss(mu, logvar)``) and folds the new parameters into
             the running average ``avg_param_G``
             (``avg = 0.999 * avg + 0.001 * param``).

        Logs are printed every 100 iterations, sample images are written
        every 1000 iterations using the averaged parameters, and
        checkpoints are written every ``cfg.TRAIN.SNAPSHOT_INTERVAL``
        epochs and once more after the last epoch.
        """
        (text_encoder, image_encoder, netG, netsD, start_epoch,
         VGG) = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        # `noise` is allocated uninitialised; it is refilled in place with
        # N(0, 1) samples before every netG call below.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Built-in next() instead of the Python 2 style `.next()`
                # method, which Python 3 iterators are not required to have.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                    wrong_caps_len, wrong_cls_id = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef

                # matched text embeddings
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # mismatched text embeddings
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(
                ), w_sent_emb.detach()

                # image features: regional and global
                region_features, cnn_code = image_encoder(imgs[len(netsD) - 1])

                # True where the caption token id is 0; trimmed so it has
                # no more columns than words_embs has positions.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar, _, _ = netG(
                    noise, sent_emb, words_embs, mask, cnn_code,
                    region_features)

                #######################################################
                # (3) Update D network
                ######################################################

                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels, words_embs,
                                              cap_lens, image_encoder,
                                              class_ids, w_words_embs,
                                              wrong_caps_len, wrong_cls_id)
                    # backward and update parameters
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD)

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens,
                                   class_ids, VGG, imgs)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # Running average of the generator parameters. The keyword
                # `alpha` form replaces the deprecated
                # add_(Number, Tensor) overload.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    # Render the samples with the averaged parameters, then
                    # restore the live training parameters.
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          cnn_code,
                                          region_features,
                                          imgs,
                                          name='average')
                    load_params(netG, backup_para)

            end_t = time.time()

            # Reports the losses of the last batch of the epoch.
            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
示例#5
0
    def train(self):
        """Run the training loop for ``netG`` and the discriminators ``netsD``.

        Every iteration:
          1. draws a batch and computes detached text embeddings and the
             caption mask;
          2. runs ``netG`` on the text embeddings to obtain one fake image
             per discriminator;
          3. takes one optimizer step on each discriminator and logs its
             loss to ``writer``;
          4. takes one optimizer step on ``netG`` and folds the new
             parameters into the running average ``avg_param_G``
             (``avg = 0.999 * avg + 0.001 * param``).

        ``zsl_discriminator`` is only passed through to
        ``generator_loss``; no optimizer is created for it here.
        ``writer``, ``LOGGER`` and ``tqdm`` are not defined in this method
        and are expected to be provided by the enclosing module.

        Logs are emitted every 100 iterations, sample images are written
        every 1000 iterations using the averaged parameters, and
        checkpoints are written every ``cfg.TRAIN.SNAPSHOT_INTERVAL``
        epochs and once more after the last epoch.
        """
        (text_encoder, image_encoder, netG, netsD, zsl_discriminator,
         start_epoch) = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            # Show a progress bar only when tqdm is available.
            it = tqdm.tqdm(range(
                self.num_batches)) if tqdm is not None else range(
                    self.num_batches)
            for step in it:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # Built-in next() instead of the Python 2 style `.next()`
                # method, which Python 3 iterators are not required to have.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # True where the caption token id is 0; trimmed so it has
                # no more columns than words_embs has positions.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                fake_imgs, _ = netG(sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    writer.add_scalar(f'd/errD/{i}', errD.item(),
                                      gen_iterations)

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = generator_loss(
                    netsD,
                    image_encoder,
                    zsl_discriminator,
                    fake_imgs,
                    real_labels,
                    words_embs,
                    sent_emb,
                    match_labels,
                    cap_lens,
                    class_ids,
                    writer=writer,
                    global_step=gen_iterations,
                )
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # Running average of the generator parameters. The keyword
                # `alpha` form replaces the deprecated
                # add_(Number, Tensor) overload.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    LOGGER.info(f'{D_logs}\n{G_logs}')
                # save images
                if gen_iterations % 1000 == 0:
                    # Render the samples with the averaged parameters, then
                    # restore the live training parameters.
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')

                gen_iterations += 1

            end_t = time.time()
            # Reports the losses of the last batch of the epoch.
            info = (f'[{epoch}/{self.max_epoch}][{self.num_batches}] '
                    f'Loss_D: {errD_total.item():.2f} '
                    f'Loss_G: {errG_total.item():.2f} '
                    f'Time: {end_t - start_t:.2f}')
            LOGGER.info(info)

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)