コード例 #1
0
    def train(self):
        """Main LGIE training loop.

        Per epoch, for each batch: encode captions (the text encoder is
        trained jointly, so embeddings are NOT detached), extract image
        features, generate edited images with ``netG``, update every
        discriminator (when ``cfg.TRAIN.W_GAN``), then update the
        generator with adversarial + KL losses.  An exponential moving
        average (EMA) of the generator weights is maintained and swapped
        in when dumping visualisations and checkpoints.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch, VGG = self.build_models(
        )
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # `iterator.next()` is Python-2 only; use the builtin next().
                data = next(data_iter)
                input_imgs_list, output_imgs_list, captions, cap_lens = prepare_data_LGIE(
                    data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef

                # matched text embeddings
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Do NOT detach: the text encoder is trained jointly here.
                # words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # if not cfg.ANNO_PATH:
                #   # mismatched text embeddings
                #   w_words_embs, w_sent_emb = text_encoder(wrong_caps, wrong_caps_len, hidden)
                #   w_words_embs, w_sent_emb = w_words_embs.detach(), w_sent_emb.detach()

                # image features: regional and global
                region_features, cnn_code = image_encoder(
                    input_imgs_list[cfg.TREE.BRANCH_NUM - 1])

                # mask out padding tokens (id 0) for attention
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar, _, _ = netG(noise, sent_emb,
                                                      words_embs, mask,
                                                      cnn_code,
                                                      region_features)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                if cfg.TRAIN.W_GAN:
                    D_logs = ''
                    for i in range(len(netsD)):
                        netsD[i].zero_grad()
                        errD = discriminator_loss(netsD[i], input_imgs_list[i],
                                                  fake_imgs[i], sent_emb,
                                                  real_labels, fake_labels)

                        # backward and update parameters; retain the graph
                        # because it is reused by the remaining Ds and by G
                        errD.backward(retain_graph=True)
                        optimizersD[i].step()
                        # accumulate a plain float so no autograd graph
                        # (and the GPU memory it pins) is kept alive
                        errD_total += errD.item()
                        D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = generator_loss(netsD, image_encoder,
                                                    fake_imgs, real_labels,
                                                    words_embs, sent_emb, None,
                                                    None, None, VGG,
                                                    output_imgs_list)
                kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.W_KL
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of G's weights.
                # `add_(scalar, tensor)` is deprecated; use the alpha kwarg.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    if cfg.TRAIN.W_GAN:
                        print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 500 == 0:
                    # temporarily swap in the EMA weights for visualisation
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens, epoch, cnn_code,
                    #                       region_features, output_imgs_list, name='average')

                    # JWT_VIS: dump rows of input / target / generated images
                    # with the caption rendered underneath each row.
                    nvis = 5
                    input_img, output_img, fake_img = input_imgs_list[
                        -1], output_imgs_list[-1], fake_imgs[-1]
                    input_img, output_img, fake_img = self.tensor_to_numpy(
                        input_img), self.tensor_to_numpy(
                            output_img), self.tensor_to_numpy(fake_img)
                    # (b x h x w x c)
                    gap = 50
                    text_bg = np.zeros((gap, 256 * 3, 3))
                    res = np.zeros((1, 256 * 3, 3))
                    for vis_idx in range(nvis):
                        cur_input_img, cur_output_img, cur_fake_img = input_img[
                            vis_idx], output_img[vis_idx], fake_img[vis_idx]
                        row = np.concatenate(
                            [cur_input_img, cur_output_img, cur_fake_img],
                            1)  # (h, w * 3, 3)
                        row = np.concatenate([row, text_bg],
                                             0)  # (h+gap, w * 3, 3)

                        # decode the caption back to words (stop at pad id 0)
                        cur_cap = captions[vis_idx].data.cpu().numpy()
                        sentence = []
                        for cap_idx in range(len(cur_cap)):
                            if cur_cap[cap_idx] == 0:
                                break
                            word = self.ixtoword[cur_cap[cap_idx]].encode(
                                'ascii', 'ignore').decode('ascii')
                            sentence.append(word)
                        cv2.putText(row, ' '.join(sentence), (40, 256 + 10),
                                    cv2.FONT_HERSHEY_PLAIN, 1.2, (0, 0, 255),
                                    1)
                        res = np.concatenate([res, row], 0)

                    # finish and write image
                    cv2.imwrite(
                        os.path.join(self.image_dir,
                                     f'G_jwtvis_{gen_iterations}.png'), res)
                    load_params(netG, backup_para)

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch, text_encoder,
                                image_encoder)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch, text_encoder,
                        image_encoder)
コード例 #2
0
ファイル: trainer.py プロジェクト: zxs789/Obj-GAN
    def train(self):
        """Obj-GAN shape-generator training loop.

        Per epoch: decay the generator learning rate (after epoch 50),
        then for every batch generate fake shape heat-maps with ``netG``,
        update the instance-level discriminator ``netINSD`` and the
        global discriminator ``netGLBD``, and update the generator with
        gradient clipping.  Maintains an EMA copy of G's weights and
        writes the epoch-averaged perceptual (pcp) score to
        ``self.score_dir``.
        """
        netG, netINSD, netGLBD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)

        batch_size = self.batch_size
        noise = Variable(
            torch.FloatTensor(batch_size, cfg.ROI.BOXES_NUM,
                              len(self.cats_index_dict) * 4))
        fixed_noise = Variable(
            torch.FloatTensor(batch_size, cfg.ROI.BOXES_NUM,
                              len(self.cats_index_dict) * 4).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        gen_iterations = 0
        lr_rate = 1
        pcp_score = 0.
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            # geometric lr decay after warm-up, floored at GENERATOR_LR / 10
            if epoch > 50 and lr_rate > cfg.TRAIN.GENERATOR_LR / 10.:
                lr_rate *= 0.98
            # NOTE(review): re-creating the optimizers each epoch resets
            # their momentum state; kept as-is because this is how the
            # original schedule applies the decayed learning rate.
            optimizerG, optimizerINSD, optimizerGLBD = self.define_optimizers(
                netG, netINSD, netGLBD, lr_rate)
            data_iter = iter(self.data_loader)
            step = 0

            while step < self.num_batches:
                #print('step: ', step)
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # `iterator.next()` is Python-2 only; use the builtin next().
                data = next(data_iter)
                imgs, pooled_hmaps, hmaps, bbox_maps_fwd, bbox_maps_bwd, bbox_fmaps, \
                    rois, fm_rois, num_rois, class_ids, keys = prepare_data(data)

                #######################################################
                # (2) Generate fake images
                ######################################################
                # only feed noise for the RoIs actually present in the batch
                max_num_roi = int(torch.max(num_rois))
                noise.data.normal_(0, 1)
                fake_hmaps = netG(noise[:, :max_num_roi], bbox_maps_fwd,
                                  bbox_maps_bwd, bbox_fmaps)

                #######################################################
                # (3-1) Update INSD network
                ######################################################
                errINSD = 0
                netINSD.zero_grad()
                errINSD = ins_discriminator_loss(netINSD, hmaps, fake_hmaps,
                                                 bbox_maps_fwd)
                errINSD.backward()
                optimizerINSD.step()
                INSD_logs = 'errINSD: %.2f ' % (errINSD.item())

                #######################################################
                # (3-2) Update GLBD network
                ######################################################
                errGLBD = 0
                netGLBD.zero_grad()
                errGLBD = glb_discriminator_loss(netGLBD, pooled_hmaps,
                                                 fake_hmaps, bbox_maps_fwd)
                errGLBD.backward()
                optimizerGLBD.step()
                GLBD_logs = 'errGLBD: %.2f ' % (errGLBD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs, item_pcp_score = generator_loss(
                    netINSD, netGLBD, self.vgg_model, hmaps, fake_hmaps,
                    bbox_maps_fwd)
                pcp_score += item_pcp_score

                errG_total.backward()
                # `clip_grad_norm` helps prevent
                # the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(netG.parameters(),
                                               cfg.TRAIN.RNN_GRAD_CLIP)
                optimizerG.step()
                # EMA update of G's weights.
                # `add_(scalar, tensor)` is deprecated; use the alpha kwarg.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % self.print_interval == 0:
                    elapsed = time.time() - start_t
                    print(
                        '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                        .format(epoch, step, self.num_batches,
                                elapsed * 1000. / self.print_interval))
                    print(INSD_logs + '\n' + GLBD_logs + '\n' + G_logs)
                    start_t = time.time()

                # save images
                if gen_iterations % self.display_interval == 0:
                    # visualise with the EMA weights, then restore
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise[:, :max_num_roi],
                                          imgs,
                                          bbox_maps_fwd,
                                          bbox_maps_bwd,
                                          bbox_fmaps,
                                          hmaps,
                                          rois,
                                          num_rois,
                                          gen_iterations,
                                          name='average')
                    load_params(netG, backup_para)

            pcp_score /= float(self.num_batches)
            print('pcp_score: ', pcp_score)
            fullpath = '%s/scores_%d.txt' % (self.score_dir, epoch)
            with open(fullpath, 'w') as fp:
                fp.write('pcp_score %f' % (pcp_score))
            # BUGFIX: reset the accumulator, otherwise the next epoch's
            # sum starts from this epoch's (already divided) average.
            pcp_score = 0.

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netINSD, netGLBD, epoch)

        self.save_model(netG, avg_param_G, netINSD, netGLBD, self.max_epoch)
コード例 #3
0
    def train(self):
        """OP-GAN-style training loop with optional bucketed data loading.

        When ``cfg.TRAIN.OPTIMIZE_DATA_LOADING`` is set, ``self.data_loader``
        is a list of loaders (one per object-count bucket) and each step
        samples a bucket with probability proportional to its remaining
        batches; otherwise a single loader is used with ``max_objects = 3``.
        Each step updates every discriminator, then the generator
        (adversarial + KL losses), and maintains an EMA copy of G.
        """
        # NOTE(review): anomaly detection adds substantial overhead;
        # presumably left on for debugging -- confirm before production runs.
        torch.autograd.set_detect_anomaly(True)

        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
            # one noise tensor per bucket (batch sizes differ per bucket)
            batch_sizes = self.batch_size
            noise, local_noise, fixed_noise = [], [], []
            for batch_size in batch_sizes:
                noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM)).to(cfg.DEVICE))
                local_noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE))
                fixed_noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM).normal_(0, 1)).to(cfg.DEVICE))
        else:
            batch_size = self.batch_size[0]
            noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM)).to(cfg.DEVICE)
            local_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE)
            fixed_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM).normal_(0, 1)).to(cfg.DEVICE)

        for epoch in range(start_epoch, self.max_epoch):
            logger.info("Epoch nb: %s" % epoch)
            gen_iterations = 0
            if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                data_iter = []
                for _idx in range(len(self.data_loader)):
                    data_iter.append(iter(self.data_loader[_idx]))
                # bucket sampling probability ~ batches left in each bucket
                total_batches_left = sum([len(self.data_loader[i]) for i in range(len(self.data_loader))])
                current_probability = [len(self.data_loader[i]) for i in range(len(self.data_loader))]
                current_probability_percent = [current_probability[i] / float(total_batches_left) for i in
                                               range(len(current_probability))]
            else:
                data_iter = iter(self.data_loader)

            _dataset = tqdm(range(self.num_batches))
            for step in _dataset:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                    subset_idx = np.random.choice(range(len(self.data_loader)), size=None,
                                                  p=current_probability_percent)
                    total_batches_left -= 1
                    if total_batches_left > 0:
                        current_probability[subset_idx] -= 1
                        current_probability_percent = [current_probability[i] / float(total_batches_left) for i in
                                                       range(len(current_probability))]

                    max_objects = subset_idx
                    # `iterator.next()` is Python-2 only; use builtin next()
                    data = next(data_iter[subset_idx])
                else:
                    data = next(data_iter)
                    max_objects = 3
                _dataset.set_description('Obj-{}'.format(max_objects))

                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot = prepare_data(data)
                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                # text encoder is frozen here: no gradients needed
                with torch.no_grad():
                    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                        hidden = text_encoder.init_hidden(batch_sizes[subset_idx])
                    else:
                        hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    # mask out padding tokens (id 0) for attention
                    mask = (captions == 0).bool()
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                    noise[subset_idx].data.normal_(0, 1)
                    local_noise[subset_idx].data.normal_(0, 1)
                    inputs = (noise[subset_idx], local_noise[subset_idx], sent_emb, words_embs, mask, transf_matrices,
                              transf_matrices_inv, label_one_hot, max_objects)
                else:
                    noise.data.normal_(0, 1)
                    local_noise.data.normal_(0, 1)
                    inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices, transf_matrices_inv,
                              label_one_hot, max_objects)

                inputs = tuple((inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor) else inp) for inp in inputs)
                fake_imgs, _, mu, logvar = netG(*inputs)

                #######################################################
                # (3) Update D network
                ######################################################
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels[subset_idx], fake_labels[subset_idx],
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv, cfg=cfg,
                                                  max_objects=max_objects)
                    else:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels,
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv, cfg=cfg,
                                                  max_objects=max_objects)

                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                gen_iterations += 1

                # do not need to compute gradient for Ds
                netG.zero_grad()
                if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                    errG_total = \
                        generator_loss(netsD, image_encoder, fake_imgs, real_labels[subset_idx],
                                       words_embs, sent_emb, match_labels[subset_idx], cap_lens, class_ids,
                                       local_labels=label_one_hot, transf_matrices=transf_matrices,
                                       transf_matrices_inv=transf_matrices_inv, max_objects=max_objects)
                else:
                    errG_total = \
                        generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                       words_embs, sent_emb, match_labels, cap_lens, class_ids,
                                       local_labels=label_one_hot, transf_matrices=transf_matrices,
                                       transf_matrices_inv=transf_matrices_inv, max_objects=max_objects)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of G's weights
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if cfg.TRAIN.EMPTY_CACHE:
                    torch.cuda.empty_cache()

                # save images roughly at the middle and the end of the epoch
                if (
                        2 * gen_iterations == self.num_batches
                        or 2 * gen_iterations + 1 == self.num_batches
                        or gen_iterations + 1 == self.num_batches
                ):
                    logger.info('Saving images...')
                    # visualise with the EMA weights, then restore
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    if cfg.TRAIN.OPTIMIZE_DATA_LOADING:
                        self.save_img_results(netG, fixed_noise[subset_idx], sent_emb,
                                              words_embs, mask, image_encoder,
                                              captions, cap_lens, epoch, transf_matrices_inv,
                                              label_one_hot, local_noise[subset_idx], transf_matrices,
                                              max_objects, subset_idx, name='average')
                    else:
                        self.save_img_results(netG, fixed_noise, sent_emb,
                                              words_embs, mask, image_encoder,
                                              captions, cap_lens, epoch, transf_matrices_inv,
                                              label_one_hot, local_noise, transf_matrices,
                                              max_objects, None, name='average')
                    load_params(netG, backup_para)

            self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)
        self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)
コード例 #4
0
    def train(self):
        """AttnGAN-style training loop with a text decoder producing
        per-frame sentence embeddings.

        Per batch: encode the caption, roll the text decoder forward
        ``num_frame`` times to build frame-level sentence embeddings,
        generate fake images with ``netG``, update each discriminator,
        then update the generator (adversarial + KL losses) and the EMA
        copy of G's weights.
        """
        text_encoder, text_decoder, image_encoder, netG, netsD, start_epoch = self.build_models(
        )
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # `iterator.next()` is Python-2 only; use the builtin next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)

                ##########SixingYu 20190918
                nef = sent_emb.shape[1]
                batch_size = sent_emb.shape[0]

                # autoregressively decode one sentence embedding per frame
                num_frame = 3
                frame_sent_emb = []
                de_hidden = text_decoder.initHidden(1, batch_size, nef)
                de_input = sent_emb
                for frame in range(num_frame):
                    de_output, de_hidden = text_decoder(
                        de_input.transpose(0, 1), de_hidden)
                    frame_sent_emb.append(de_output)
                    de_input = de_output
                # frame_sent_emb : num_frames x batch_size x nef

                ##########SixingYu 20190917
                # decoder_input = Variable(torch.FloatTensor([[0]])).cuda()
                # decoder_hidden = sent_emb
                # for l in range(text_decoder.max_length):
                #     decoder_input, decoder_hidden, decoder_attention = text_decoder(decoder_input, decoder_hidden, words_embs)
                # ##########
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # mask out padding tokens (id 0) for attention
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    # `.data[0]` was removed in PyTorch >= 0.5; use .item()
                    errD_total += errD.item()
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of G's weights.
                # `add_(scalar, tensor)` is deprecated; use the alpha kwarg.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    # visualise with the EMA weights, then restore
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #5
0
    def train(self):
        """Run the full adversarial training loop.

        For every batch: encode the captions, generate fake images from
        noise + text embeddings, update each discriminator, then update the
        generator (adversarial + captioning losses + KL term).  An
        exponential moving average (EMA) of the generator weights is
        maintained and model snapshots are saved every
        cfg.TRAIN.SNAPSHOT_INTERVAL epochs.
        """
        text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)  # EMA copy of the generator parameters
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)

        # real_labels: all ones, fake_labels: all zeros,
        # match_labels: 0..batch_size-1 (for text-image matching loss)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM  # random noise vector dimension
        noise = Variable(torch.FloatTensor(batch_size, nz))  # re-sampled every batch
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))  # kept for visualisation
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # (1) Prepare training data and compute text embeddings.
                # FIX: use the builtin next(); iterator.next() is Python-2-only
                # and raises AttributeError under Python 3.
                data = next(data_iter)
                # prepare_data moves the batch to CUDA and sorts by caption length.
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len (all RNN hidden states)
                # sent_emb:   batch_size x nef          (final hidden state)
                # nef = num_hidden * num_directions
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Text encoder is frozen here: detach so no gradients reach it.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask padding positions (token id 0) so attention ignores them.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # (2) Generate fake images.
                noise.data.normal_(0, 1)
                # mu/logvar come from the conditioning-augmentation module;
                # fake_imgs holds one image per generator branch.
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                # (3) Update the D networks (one per resolution branch).
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.6f ' % (i, errD.data)

                # (4) Update G network: maximize log(D(G(z))).
                step += 1
                gen_iterations += 1
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.6f ' % kl_loss.data
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights: avg = 0.999*avg + 0.001*p
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # Log roughly 100 times per epoch.
                # FIX: use floor division so the modulus stays an int; the
                # original true division (num_batches/100 + 1) produced a float
                # modulus that is almost never exactly 0 under Python 3.
                if gen_iterations % (self.num_batches // 100 + 1) == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                # if gen_iterations % 1000 == 0:
                #     backup_para = copy_G_params(netG)
                #     load_params(netG, avg_param_G)
                #     self.save_img_results(netG, fixed_noise, sent_emb,
                #                           words_embs, mask, image_encoder,
                #                           captions, cap_lens, epoch, name='average')
                #     load_params(netG, backup_para)
            end_t = time.time()

            print('''Epoch [%d/%d][%d]
                  Loss_D: %.6f Loss_G: %.6f Time: %.6fs\n'''
                  % (epoch + 1, self.max_epoch, self.num_batches,
                     errD_total.data, errG_total.data,
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #6
0
    def train(self):
        """Adversarial training with an auxiliary DDVA objective.

        Each batch carries a (human caption, imperfect caption) pair.  The
        main generator produces images from the human captions; a frozen
        target copy (periodically synced) produces images from the imperfect
        captions on a secondary device.  Both result sets are aligned by key
        and a negative-DDVA loss is added to the generator objective.
        Statement order is preserved from the original implementation.
        """
        text_encoder, image_encoder, netG, target_netG, netsD, start_epoch, style_loss = self.build_models(
        )
        avg_param_G = copy_G_params(netG)  # EMA copy of generator parameters
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                # FIX: builtin next(); iterator.next() is Python-2-only and
                # raises AttributeError under Python 3.
                data = next(data_iter)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for human-text ----------------------------------------------------------------
                data_human = [captions, cap_lens, misc]

                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Text encoder is frozen: detach so no gradients flow into it.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # Mismatched ("wrong") word and sentence embeddings for the
                # matching-aware discriminator loss.
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(
                ), w_sent_emb.detach()

                # Mask padding positions (token id 0) for the attention module.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                # Generate images for imperfect caption-text-------------------------------------------------------

                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]

                imgs, imperfect_captions, imperfect_cap_lens, i_class_ids, imperfect_keys, i_wrong_caps,\
                            i_wrong_caps_len, i_wrong_cls_id = prepare_data(data_imperfect)

                i_hidden = text_encoder.init_hidden(batch_size)
                i_words_embs, i_sent_emb = text_encoder(
                    imperfect_captions, imperfect_cap_lens, i_hidden)
                i_words_embs, i_sent_emb = i_words_embs.detach(
                ), i_sent_emb.detach()
                i_mask = (imperfect_captions == 0)
                i_num_words = i_words_embs.size(2)

                if i_mask.size(1) > i_num_words:
                    i_mask = i_mask[:, :i_num_words]

                # Move tensors to the secondary device.
                # NOTE(review): `secondary_device` is not defined in this
                # method — presumably a module-level global; confirm.
                noise = noise.to(secondary_device
                                 )  # IMPORTANT! We are reusing the same noise.
                i_sent_emb = i_sent_emb.to(secondary_device)
                i_words_embs = i_words_embs.to(secondary_device)
                i_mask = i_mask.to(secondary_device)

                # Generate images with the (periodically synced) target copy.
                imperfect_fake_imgs, _, _, _ = target_netG(
                    noise, i_sent_emb, i_words_embs, i_mask)

                # Sort the results by keys to align ------------------------------------------------------------------------
                bag = [
                    sent_emb, real_labels, fake_labels, words_embs, class_ids,
                    w_words_embs, wrong_caps_len, wrong_cls_id
                ]

                keys, captions, cap_lens, fake_imgs, _, sorted_bag = sort_by_keys(keys, captions, cap_lens, fake_imgs,\
                                                                                  None, bag)

                sent_emb, real_labels, fake_labels, words_embs, class_ids, w_words_embs, wrong_caps_len, wrong_cls_id = \
                            sorted_bag

                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, _ = \
                            sort_by_keys(imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs,None)

                #-----------------------------------------------------------------------------------------------------------

                # Update each discriminator (one per resolution branch).
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels, words_embs,
                                              cap_lens, image_encoder,
                                              class_ids, w_words_embs,
                                              wrong_caps_len, wrong_cls_id)
                    # retain_graph: the fake-image graph is shared by the
                    # remaining D updates and the later G update.
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD)

                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, style_loss, imgs)
                kl_loss = KL_loss(mu, logvar)

                errG_total += kl_loss

                G_logs += 'kl_loss: %.2f ' % kl_loss

                # Shift device for the imgs and target_imgs.-----------------------------------------------------
                for i in range(len(imgs)):
                    imgs[i] = imgs[i].to(secondary_device)
                    fake_imgs[i] = fake_imgs[i].to(secondary_device)

                # Compute and add ddva loss ---------------------------------------------------------------------
                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)
                neg_ddva *= 10.  # Scale so that the ddva score is not overwhelmed by other losses.
                errG_total += neg_ddva.to(cfg.GPU_ID)
                G_logs += 'negative_ddva_loss: %.2f ' % neg_ddva
                #------------------------------------------------------------------------------------------------

                errG_total.backward()

                optimizerG.step()
                # EMA update of generator weights: avg = 0.999*avg + 0.001*p
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)

                # Copy parameters to the target network.
                if gen_iterations % 20 == 0:
                    load_params(target_netG, copy_G_params(netG))

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f neg_ddva: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total, neg_ddva, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #7
0
ファイル: trainer.py プロジェクト: EloiseXu/posetrack
    def train(self, gen_iterations, start_t, epoch, lr):
        """Train pose-GAN for one epoch.

        Generates heatmaps with ``self.netG``, updates each stacked
        discriminator, then updates the generator, with optional live
        visualisation when ``self.debug`` is set.

        Args:
            gen_iterations: running generator-iteration counter (carried
                across epochs by the caller).
            start_t: wall-clock time at which training started.
            epoch: current epoch index (for logging / snapshots).
            lr: learning rate, forwarded to ``save_model``.
        """
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.netG.train()

        end = time.time()

        gt_win, pred_win = None, None
        bar = Bar('Train', max=len(self.train_loader))
        step = 0
        errD_total = None
        errG_total = None
        for i, (input, target, meta, mpii) in enumerate(self.train_loader):
            data_time.update(time.time() - end)

            ######################################################
            # (1) Prepare training data and Compute text embeddings
            ######################################################
            input, target = input.to(self.device), target.to(self.device,
                                                             non_blocking=True)
            target_weight = meta['target_weight'].to(self.device,
                                                     non_blocking=True)

            #######################################################
            # (2) Generate fake heatmaps
            ######################################################
            output = self.netG(input)

            #######################################################
            # (3) Update D network
            ######################################################
            errD_total = 0
            D_logs = ''
            # FIX: loop variable renamed from `i` to `stack_idx` — the
            # original shadowed the enumerate index `i`, corrupting the
            # progress-bar batch counter below.
            for stack_idx in range(self.num_stacks):
                self.netsD[stack_idx].zero_grad()
                errD = discriminator_loss(self.netsD[stack_idx], target,
                                          target_weight, output[stack_idx],
                                          input, self.real_labels,
                                          self.fake_labels, mpii)

                # FIX: was errD.backword() (typo) — an AttributeError that
                # meant the discriminator never received gradients.
                errD.backward()
                self.optimizersD[stack_idx].step()
                errD_total += errD
                # FIX: format was '%d.2f' (prints the int then literal
                # ".2f"); '%.2f' is the intended fixed-point format.
                D_logs += 'errD%d: %.2f ' % (stack_idx, errD.data[0])

            #######################################################
            # (4) Update G network: maximize log(D(G(z)))
            ######################################################
            step += 1
            gen_iterations += 1

            self.netG.zero_grad()
            errG_total, G_logs = \
                generator_loss(self.netsD, self.domainD, output, self.real_labels, input, target_weight, mpii)

            if self.debug:
                # Side-by-side live view of ground-truth vs predicted heatmaps.
                gt_batch_img = batch_with_heatmap(input, target)
                pred_batch_img = batch_with_heatmap(input, output)
                if not gt_win or not pred_win:
                    ax1 = plt.subplot(121)
                    ax1.title.set_text('Groundtruth')
                    gt_win = plt.imshow(gt_batch_img)
                    ax2 = plt.subplot(122)
                    ax2.title.set_text('Prediction')
                    pred_win = plt.imshow(pred_batch_img)
                else:
                    gt_win.set_data(gt_batch_img)
                    pred_win.set_data(pred_batch_img)
                plt.pause(.05)
                plt.draw()

            errG_total.backward()
            self.optimizerG.step()

            batch_time.update(time.time() - end)
            end = time.time()

            bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                batch=i + 1,
                size=len(self.train_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            bar.next()

            if gen_iterations % 100 == 0:
                print(D_logs + '\n' + G_logs)

        end_t = time.time()

        print('''[%d/%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
              (epoch, self.epochs, errD_total.data[0], errG_total.data[0],
               end_t - start_t))

        if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
            self.save_model(self.netsD, lr, epoch)

        bar.finish()
コード例 #8
0
ファイル: trainer.py プロジェクト: LeoXing1996/AttnGAN
    def train(self):
        """Adversarial training loop with wandb logging and optional Apex AMP.

        Per batch: encode captions, generate fake images, update each
        discriminator, update the generator (adversarial + KL losses), and
        keep an exponential moving average (EMA) of the generator weights.
        Losses are streamed to wandb; image samples are saved every 1000
        generator iterations, model snapshots every
        cfg.TRAIN.SNAPSHOT_INTERVAL epochs.
        """
        wandb.init(name=cfg.EXP_NAME, project='AttnGAN', config=cfg, dir='../logs')

        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        text_encoder, image_encoder, netG, netsD, optimizerG, optimizersD =  \
            self.apply_apex(text_encoder, image_encoder, netG, netsD, optimizerG, optimizersD)
        # add watch
        wandb.watch(netG)
        for D in netsD:
            wandb.watch(D)

        # FIX: import amp once here instead of re-importing inside the
        # per-batch training loop.
        if cfg.APEX:
            from apex import amp

        avg_param_G = copy_G_params(netG)  # EMA copy of generator parameters
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        log_dict = {}
        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # FIX: builtin next(); iterator.next() is Python-2-only and
                # raises AttributeError under Python 3.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Text encoder is frozen: detach so no gradients reach it.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask padding positions (token id 0) for attention.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters (AMP-scaled if enabled)
                    if cfg.APEX:
                        with amp.scale_loss(errD, optimizersD[i], loss_id=i) as errD_scaled:
                            errD_scaled.backward()
                    else:
                        errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    log_name = 'errD_{}'.format(i)
                    log_dict[log_name] = errD.item()

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs, G_log_dict = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                log_dict.update(G_log_dict)
                log_dict['kl_loss'] = kl_loss.item()
                # backward and update parameters (AMP-scaled if enabled)
                if cfg.APEX:
                    with amp.scale_loss(errG_total, optimizerG, loss_id=len(netsD)) as errG_scaled:
                        errG_scaled.backward()
                else:
                    errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights: avg = 0.999*avg + 0.001*p
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                wandb.log(log_dict)
                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                    wandb.save('logs.ckpt')
                # save images
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #9
0
    def train(self):
        """Adversarial training loop for a noise-free generator variant.

        The generator here is conditioned only on text embeddings (no latent
        noise vector and no KL term).  Per batch: encode captions, generate
        fake images, update each discriminator, then update the generator;
        an exponential moving average (EMA) of the generator weights is kept
        and snapshots are saved every cfg.TRAIN.SNAPSHOT_INTERVAL epochs.
        """
        text_encoder, image_encoder, features_discriminator, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)  # EMA copy of generator parameters
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            for step in tqdm(range(self.num_batches)):
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # FIX: builtin next(); iterator.next() is Python-2-only and
                # raises AttributeError under Python 3.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Detached: the text encoder is kept frozen during GAN training.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask padding positions (token id 0) for attention.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                fake_imgs, _ = netG(sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, features_discriminator, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights: avg = 0.999*avg + 0.001*p
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 1:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 1:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #10
0
    def train(self):
        """Adversarial training loop with captioning-model guidance.

        Per batch: encode captions (masking <start>/<end>/<pad> tokens),
        generate fake images from noise + text embeddings, update each
        discriminator, then update the generator (adversarial + captioning
        + KL losses).  An exponential moving average (EMA) of the generator
        weights is kept; samples are saved every 1000 generator iterations
        and models every cfg.TRAIN.SNAPSHOT_INTERVAL epochs.
        """
        text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)  # EMA copy of generator parameters
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # (1) Prepare training data and Compute text embeddings
                # FIX: builtin next(); iterator.next() is Python-2-only and
                # raises AttributeError under Python 3.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Text encoder is frozen: detach so no gradients reach it.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0) + (captions == 1) + (captions == 2)  # masked <start>, <end>, <pad>
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # (2) Generate fake images
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                # (3) Update D network (one per resolution branch)
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.data.item())

                # (4) Update G network: maximize log(D(G(z)))
                # compute total loss for training G
                step += 1
                gen_iterations += 1
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights: avg = 0.999*avg + 0.001*p
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    print('Saving images...')
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data.item(), errG_total.data.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #11
0
    def train(self):
        """Run the adversarial training loop (two-level RNN text-encoder variant).

        Per batch: encode the QA captions with the hierarchical RNN pair,
        generate fake images, update every discriminator, then update the
        generator with GAN + KL (+ optional VQA-consistency) losses while
        maintaining an exponential moving average (EMA) of G's weights.
        Periodically prints logs, dumps sample images and saves checkpoints.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models(
        )
        # text_encoder is a pair of RNNs: H (high level, per-QA) and
        # L (low level, per-word).
        H_rnn_model, L_rnn_model = text_encoder
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)

        # Optional external VQA loss: a frozen image-feature extractor plus a
        # VQA net that checks whether generated images answer the questions.
        if cfg.TRAIN.EVQAL.B_EVQAL:
            netVQA_E = load_resnet_image_encoder(model_stage=2)
            netVQA = load_vqa_net(cfg.TRAIN.EVQAL.NET,
                                  load_program_vocab(
                                      cfg.TRAIN.EVQAL.PROGRAM_VOCAB_FILE),
                                  feature_dim=(512, 28, 28))
        else:
            netVQA_E = netVQA = None

        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            am_vqa_loss = AverageMeter('VQA Loss')
            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # FIX: iterator.next() is Python 2 only; use the builtin next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, bbox, label_one_hot, transformation_matrices, keys, prog = self.prepare_data(
                    data)
                class_ids = None
                batch_size = captions.size(0)

                # Detach: the spatial-transform matrices are inputs, not
                # trainable through this loop.
                transf_matrices = transformation_matrices[0].detach()
                transf_matrices_inv = transformation_matrices[1].detach()

                # Hierarchical text encoding: per-QA embeddings, their average,
                # and the number of QA pairs per sample.
                per_qa_embs, avg_qa_embs, qa_nums =\
                    Level2RNNEncodeMagic(captions, cap_lens, L_rnn_model, H_rnn_model)
                # The text encoder is frozen here: stop gradients at its output.
                per_qa_embs, avg_qa_embs = (per_qa_embs.detach(),
                                            avg_qa_embs.detach())

                # Padding mask: 1 marks QA slots beyond qa_nums (padding),
                # 0 marks valid slots.
                _nmaxqa = cfg.TEXT.MAX_QA_NUM
                mask = torch.ones(batch_size, _nmaxqa,
                                  dtype=torch.uint8).cuda()
                _ref = torch.arange(0, _nmaxqa).view(1,
                                                     -1).repeat(batch_size,
                                                                1).cuda()
                _targ = qa_nums.view(-1, 1).repeat(1, _nmaxqa)
                mask[_ref < _targ] = 0
                num_words = per_qa_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, avg_qa_embs, per_qa_embs, mask,
                          transf_matrices_inv, label_one_hot)
                fake_imgs, _, mu, logvar = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    if i == 0:  # NOTE only the first level Discriminator is modified.
                        errD = discriminator_loss(
                            netsD[i],
                            imgs[i],
                            fake_imgs[i],
                            avg_qa_embs,
                            real_labels,
                            fake_labels,
                            self.gpus,
                            local_labels=label_one_hot,
                            transf_matrices=transf_matrices,
                            transf_matrices_inv=transf_matrices_inv)
                    else:
                        errD = discriminator_loss(netsD[i], imgs[i],
                                                  fake_imgs[i], avg_qa_embs,
                                                  real_labels, fake_labels,
                                                  self.gpus)

                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   per_qa_embs, avg_qa_embs, match_labels, qa_nums, class_ids, self.gpus,
                                   local_labels=label_one_hot, transf_matrices=transf_matrices,
                                   transf_matrices_inv=transf_matrices_inv)

                # KL term only applies when the conditioning-augmentation net
                # (CA) produced (mu, logvar).
                if cfg.GAN.B_CA_NET:
                    kl_loss = KL_loss(mu, logvar)
                else:
                    kl_loss = torch.FloatTensor([0.]).squeeze().cuda()

                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()

                if cfg.TRAIN.EVQAL.B_EVQAL:
                    fake_img_fvqa = extract_image_feats(
                        fake_imgs[-1], netVQA_E, self.gpus)
                    errVQA = VQA_loss(netVQA, fake_img_fvqa, prog['programs'],
                                      prog['answers'], self.gpus)
                else:
                    errVQA = torch.FloatTensor([0.]).squeeze().cuda()
                G_logs += 'VQA_loss: %.2f ' % errVQA.data.item()
                beta = cfg.TRAIN.EVQAL.BETA
                errG_total += (errVQA * beta)

                # backward and update parameters
                errG_total.backward()
                optimizerG.step()

                am_vqa_loss.update(errVQA.cpu().item())
                # EMA of G's weights.
                # FIX: add_(Number, Tensor) is deprecated; use the alpha kwarg.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # save images
                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                if gen_iterations % 500 == 0:  # FIXME original: 1000
                    # Sample with the EMA weights, then restore live weights.
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(imgs,
                                          netG,
                                          fixed_noise,
                                          avg_qa_embs,
                                          per_qa_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          transf_matrices_inv,
                                          label_one_hot,
                                          name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), end_t - start_t))
            if cfg.TRAIN.EVQAL.B_EVQAL:
                print('Avg. VQA Loss of this epoch: %s' % str(am_vqa_loss))
            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, optimizerG,
                                optimizersD, epoch)

        self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD,
                        epoch)
Code example #12
0
    def train(self):
        """Training loop with a frozen target generator for DDVA consistency.

        Generates images from human captions with netG and from imperfect
        captions with target_netG (a frozen copy of netG, sharing the same
        noise), aligns both batches by key, then updates the discriminators
        and the generator; G additionally minimises a scaled negative-DDVA
        loss between the two generations. target_netG is refreshed from netG
        every step and kept non-trainable.
        """
        text_encoder, image_encoder, netG, target_netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # Sliding window over the last 100 negative-DDVA values for logging.
        sliding_window = []

        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # FIX: iterator.next() is Python 2 only; use the builtin next().
                data = next(data_iter)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for human-text ----------------------------------------------------------------
                data_human = [captions, cap_lens, misc]

                imgs, captions, cap_lens, class_ids, keys = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Text encoder is frozen: stop gradients at its output.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Token id 0 is padding; mask those positions.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask, cap_lens)

                # Generate images for imperfect caption-text-------------------------------------------------------

                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]

                imgs, imperfect_captions, imperfect_cap_lens, i_class_ids, imperfect_keys = prepare_data(data_imperfect)

                i_hidden = text_encoder.init_hidden(batch_size)
                i_words_embs, i_sent_emb = text_encoder(imperfect_captions, imperfect_cap_lens, i_hidden)
                i_words_embs, i_sent_emb = i_words_embs.detach(), i_sent_emb.detach()
                i_mask = (imperfect_captions == 0)
                i_num_words = i_words_embs.size(2)

                if i_mask.size(1) > i_num_words:
                    i_mask = i_mask[:, :i_num_words]

                # Move tensors to the secondary device.
                #noise  = noise.to(secondary_device) # IMPORTANT! We are reusing the same noise.
                #i_sent_emb = i_sent_emb.to(secondary_device)
                #i_words_embs = i_words_embs.to(secondary_device)
                #i_mask = i_mask.to(secondary_device)

                # Generate with the frozen target network, reusing the same noise.
                imperfect_fake_imgs, _, _, _ = target_netG(noise, i_sent_emb, i_words_embs, i_mask)

                # Sort the results by keys to align ------------------------------------------------------------------------
                bag = [sent_emb, real_labels, fake_labels, words_embs, class_ids]

                keys, captions, cap_lens, fake_imgs, _, sorted_bag = sort_by_keys(keys, captions, cap_lens, fake_imgs,\
                                                                                  None, bag)

                sent_emb, real_labels, fake_labels, words_embs, class_ids = \
                            sorted_bag

                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, _ = \
                            sort_by_keys(imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs,None)

                #-----------------------------------------------------------------------------------------------------------

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD, log = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    D_logs += log

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()

                # Shift device -----------------------------------------------------
                #for i in range(len(imgs)):
                #    imgs[i] = imgs[i].to(secondary_device)
                #    fake_imgs[i] = fake_imgs[i].to(secondary_device)

                # FIX: this prints the generator loss, not the discriminator loss.
                print('Generator loss: ', errG_total)

                # Compute and add ddva loss ---------------------------------------------------------------------
                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)
                neg_ddva *= 10. # Scale so that the ddva score is not overwhelmed by other losses.
                errG_total += neg_ddva.to(cfg.GPU_ID)
                #G_logs += 'negative_ddva_loss: %.2f ' % neg_ddva
                #------------------------------------------------------------------------------------------------

                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA of G's weights.
                # FIX: add_(Number, Tensor) is deprecated; use the alpha kwarg.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # FIX: this line was tab-indented (TabError in Python 3);
                # re-indented with spaces.
                if len(sliding_window) == 100:
                    del sliding_window[0]
                sliding_window.append(neg_ddva)
                sliding_avg_ddva = sum(sliding_window)/len(sliding_window)

                print('sliding_window avg NEG DDVA: ',sliding_avg_ddva)
                print('Negative ddva: ', neg_ddva)

                # FIX: the last continuation line below was left uncommented,
                # which made the file a SyntaxError; the whole statement is
                # now commented out consistently.
                #if gen_iterations % 100 == 0:
                #    print('Epoch [{}/{}] Step [{}/{}]'.format(epoch, self.max_epoch, step,
                #                                              self.num_batches) + ' ' + D_logs + ' ' + G_logs)

                # Copy parameters to the target network.
                #if gen_iterations % 4 == 0:
                load_params(target_netG, copy_G_params(netG))
                # Disable training in the target network:
                for p in target_netG.parameters():
                    p.requires_grad = False

            end_t = time.time()

            #print('''[%d/%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' % (
            #    epoch, self.max_epoch, errD_total.item(), errG_total.item(), end_t - start_t))
            #print('-' * 89)
            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)
Code example #13
0
    def train(self):
        """Training loop pairing human and imperfect captions with a DDVA loss.

        Generates one batch of images from human captions and one from
        imperfect captions with the same generator (fresh noise each time),
        aligns both batches by key, updates the discriminators, then updates
        the generator with GAN + KL + negative-DDVA losses.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models(
        )
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                # FIX: iterator.next() is Python 2 only; use the builtin next().
                data = next(data_iter)

                captions, cap_lens, imperfect_captions, imperfect_cap_lens, misc = data

                # Generate images for human-text ----------------------------------------------------------------
                data_human = [captions, cap_lens, misc]

                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id = prepare_data(data_human)

                hidden = text_encoder.init_hidden(batch_size)
                # Text encoder is frozen: stop gradients at its output.
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # wrong (mismatched) word and sentence embeddings
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(
                ), w_sent_emb.detach()

                # Token id 0 is padding; mask those positions.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                # Generate images for imperfect caption-text-------------------------------------------------------

                data_imperfect = [imperfect_captions, imperfect_cap_lens, misc]

                imgs, imperfect_captions, imperfect_cap_lens, i_class_ids, imperfect_keys, i_wrong_caps,\
                            i_wrong_caps_len, i_wrong_cls_id = prepare_data(data_imperfect)

                i_hidden = text_encoder.init_hidden(batch_size)
                i_words_embs, i_sent_emb = text_encoder(
                    imperfect_captions, imperfect_cap_lens, i_hidden)
                i_words_embs, i_sent_emb = i_words_embs.detach(
                ), i_sent_emb.detach()
                i_mask = (imperfect_captions == 0)
                i_num_words = i_words_embs.size(2)

                if i_mask.size(1) > i_num_words:
                    i_mask = i_mask[:, :i_num_words]

                # Fresh noise for the imperfect-caption generation.
                noise.data.normal_(0, 1)
                imperfect_fake_imgs, _, _, _ = netG(noise, i_sent_emb,
                                                    i_words_embs, i_mask)

                # Sort the results by keys to align ------------------------------------------------------------------------
                bag = [
                    sent_emb, real_labels, fake_labels, words_embs, class_ids,
                    w_words_embs, wrong_caps_len, wrong_cls_id
                ]

                keys, captions, cap_lens, fake_imgs, _, sorted_bag = sort_by_keys(keys, captions, cap_lens, fake_imgs,\
                                                                                  None, bag)

                sent_emb, real_labels, fake_labels, words_embs, class_ids, w_words_embs, wrong_caps_len, wrong_cls_id = \
                            sorted_bag

                imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs, _ = \
                            sort_by_keys(imperfect_keys, imperfect_captions, imperfect_cap_lens, imperfect_fake_imgs, imgs,None)

                #-----------------------------------------------------------------------------------------------------------

                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels, words_embs,
                                              cap_lens, image_encoder,
                                              class_ids, w_words_embs,
                                              wrong_caps_len, wrong_cls_id)
                    # backward and update parameters
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    # FIX: format the scalar value, not the tensor (consistent
                    # with the sibling training loops).
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, imgs)
                kl_loss = KL_loss(mu, logvar)

                errG_total += kl_loss

                G_logs += 'kl_loss: %.2f ' % kl_loss.item()

                # Compute and add ddva loss ---------------------------------------------------------------------
                neg_ddva = negative_ddva(imperfect_fake_imgs, imgs, fake_imgs)

                errG_total += neg_ddva

                G_logs += 'negative_ddva_loss: %.2f ' % neg_ddva

                errG_total.backward()

                optimizerG.step()
                # EMA of G's weights.
                # FIX: add_(Number, Tensor) is deprecated; use the alpha kwarg.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)

                # save images
                #if gen_iterations % 1000 == 0:
                #    backup_para = copy_G_params(netG)
                #    load_params(netG, avg_param_G)
                #    self.save_img_results(netG, fixed_noise, sent_emb,
                #                          words_embs, mask, image_encoder,
                #                          captions, cap_lens, epoch, name='average')
                #    load_params(netG, backup_para)
                # NOTE(review): this break runs only ONE batch per epoch —
                # looks like leftover debugging; confirm before removing.
                break

            end_t = time.time()

            # FIX: print scalar values, not tensors.
            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), end_t - start_t))
Code example #14
0
    def train(self):
        """Training loop for text-guided image MODIFICATION (single D, VGG).

        Each batch: encode matched and mismatched captions, extract VGG
        features of the real image, have netG modify the real image under the
        caption, then update the single discriminator and the generator
        (GAN + KL losses) while keeping an EMA of G's weights. Periodically
        logs, saves sample images and checkpoints.
        """
        text_encoder, netG, netD, start_epoch, VGG = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizerD = self.define_optimizers(netG, netD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # FIX: iterator.next() is Python 2 only; use the builtin next().
                data = next(data_iter)
                imgs, w_imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id, sorted_cap_indices, w_sorted_cap_indices = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef

                # matched text embeddings (text encoder frozen: detach).
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # mismatched text embeddings
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(
                ), w_sent_emb.detach()
                ### re-arrange the wrong-caption tensors back into the
                ### original (unsorted) batch order.
                w_words_embs = self.reverse_indices(w_words_embs,
                                                    sorted_cap_indices,
                                                    w_sorted_cap_indices)
                w_sent_emb = self.reverse_indices(w_sent_emb,
                                                  sorted_cap_indices,
                                                  w_sorted_cap_indices)
                wrong_caps = self.reverse_indices(wrong_caps,
                                                  sorted_cap_indices,
                                                  w_sorted_cap_indices)
                wrong_caps_len = self.reverse_indices(wrong_caps_len,
                                                      sorted_cap_indices,
                                                      w_sorted_cap_indices)
                wrong_cls_id = self.reverse_indices(wrong_cls_id,
                                                    sorted_cap_indices,
                                                    w_sorted_cap_indices)
                # image features: regional and global
                #region_features, cnn_code = image_encoder(imgs[len(netsD)-1])

                # Token id 0 is padding; mask those positions.
                mask = (captions == 0)  ##
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Modify real images
                ######################################################
                noise.data.normal_(0, 1)
                enc_features = VGG(imgs[-1])
                fake_img, mu, logvar = nn.parallel.data_parallel(
                    netG, (imgs[-1], sent_emb, words_embs, noise, mask,
                           enc_features), self.gpus)

                #######################################################
                # (3) Update D network
                ######################################################

                netD.zero_grad()
                errD, D_logs = discriminator_loss(netD, imgs[-1], fake_img,
                                                  sent_emb, w_sent_emb,
                                                  real_labels, fake_labels)
                errD.backward()
                optimizerD.step()
                """
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels,
                                              words_embs, cap_lens, image_encoder, class_ids, w_words_embs, 
                                              wrong_caps_len, wrong_cls_id)
                    # backward and update parameters
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD)
                """

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netD, fake_img,  imgs[-1], w_imgs[-1], real_labels, sent_emb, VGG, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                # FIX: format the scalar value, not the tensor.
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()

                #self.save_img_results(netG, fixed_noise, w_sent_emb,
                #                          w_words_embs, captions, wrong_caps, epoch, imgs, mask, VGG)

                #self.save_img_results(netG, fixed_noise, w_sent_emb,
                #                          w_words_embs, captions, wrong_caps, epoch, imgs)
                #self.save_model(netG, avg_param_G, netD, epoch)

                # EMA of G's weights.
                # FIX: add_(Number, Tensor) is deprecated; use the alpha kwarg.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    # Sample with the EMA weights, then restore live weights.
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, w_sent_emb, w_words_embs,
                                          captions, wrong_caps, epoch, imgs,
                                          mask, VGG)
                    load_params(netG, backup_para)

            end_t = time.time()

            # FIX: print scalar values, not tensors.
            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD.item(),
                   errG_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netD, epoch)

        self.save_model(netG, avg_param_G, netD, self.max_epoch)
Code example #15
0
    def train(self):
        """Run the adversarial training loop (variant with a style loss).

        For every batch: embeds matched and mismatched captions with a
        frozen text encoder, generates fake images, updates each
        discriminator, then updates the generator (adversarial + style +
        KL losses). An exponential moving average (EMA) of the generator
        weights is maintained and used for periodic image snapshots and
        model checkpoints.
        """
        # build_models() returns encoders, G, per-scale Ds, resume epoch,
        # and the style-loss network passed through to generator_loss.
        text_encoder, image_encoder, netG, netsD, start_epoch, style_loss = \
            self.build_models()
        avg_param_G = copy_G_params(netG)  # EMA copy of generator weights
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is reused for snapshots so samples stay comparable
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:

                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # next(data_iter): iterator.next() was removed in Python 3;
                # the next() builtin works on both Python 2 and 3.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                                wrong_caps_len, wrong_cls_id = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # detach: the text encoder is not trained in this loop
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                # mismatched ("wrong") embeddings, used as negatives
                # by the conditional discriminator loss
                w_words_embs, w_sent_emb = text_encoder(
                    wrong_caps, wrong_caps_len, hidden)
                w_words_embs, w_sent_emb = w_words_embs.detach(
                ), w_sent_emb.detach()

                # mask padding tokens (id 0) so attention ignores them
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D networks (one per image scale)
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels, words_embs,
                                              cap_lens, image_encoder,
                                              class_ids, w_words_embs,
                                              wrong_caps_len, wrong_cls_id)
                    # retain_graph: the fake images are shared by every D
                    # loss and by the generator update below
                    errD.backward(retain_graph=True)
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD)

                step += 1
                gen_iterations += 1

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                netG.zero_grad()
                errG_total, G_logs, w_loss, s_loss = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, style_loss, imgs)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss

                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights: avg = 0.999*avg + 0.001*p
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)

                # save sample images with the EMA generator weights,
                # then restore the live training weights
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total,
                   errG_total, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #16
0
    def train(self):
        """Run the adversarial training loop and plot loss curves.

        Alternates discriminator and generator updates per batch, keeps
        an exponential moving average (EMA) of the generator weights for
        snapshots, records per-epoch losses, and shows matplotlib plots
        of the loss histories once training finishes.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = \
            self.build_models()
        avg_param_G = copy_G_params(netG)  # EMA copy of generator weights
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is reused for snapshots so samples stay comparable
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches

        # per-epoch loss histories; stored as plain floats (see below)
        errorD = []
        errorG = []
        loss_KL = []
        loss_s = []
        loss_w = []

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # next(data_iter): iterator.next() was removed in Python 3;
                # the next() builtin works on both Python 2 and 3.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # detach: the text encoder is not trained in this loop
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # mask padding tokens (id 0) so attention ignores them
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D networks (one per image scale)
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.data)

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs, w_loss, s_loss = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights: avg = 0.999*avg + 0.001*p
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save sample images with the EMA weights, then restore
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.data,
                   errG_total.data, end_t - start_t))

            # store plain Python floats, not the live loss tensors:
            # appending tensors would keep them (and any CUDA storage)
            # alive across epochs, and plt.plot cannot handle CUDA tensors
            errorD.append(float(errD_total))
            errorG.append(float(errG_total))
            loss_KL.append(float(kl_loss))
            loss_s.append(float(s_loss))
            loss_w.append(float(w_loss))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        plt.plot(errorG, label="Generator Loss")
        plt.plot(errorD, label="Discriminator Loss")
        plt.legend()
        plt.title("loss function for each epoch")
        plt.show()

        plt.plot(loss_KL, label="KL Loss")
        plt.title("KL loss function")
        plt.show()

        plt.plot(loss_s, label="sent Loss")
        plt.plot(loss_w, label="word Loss")
        plt.legend()
        plt.title("specfic loss function in generator")
        plt.show()

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #17
0
    def train(self):
        """Run the adversarial training loop with TensorBoard logging.

        Every 20 generator iterations, logs generated sample images and
        scalar losses to the module-level ``writer`` (a SummaryWriter).
        Alternates D and G updates per batch and keeps an exponential
        moving average (EMA) of the generator weights, saved alongside
        the encoders in each checkpoint.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)  # EMA copy of generator weights
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is reused for snapshots so samples stay comparable
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            # NOTE(review): debug output left in by the author
            print(self.num_batches)
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # next(data_iter): iterator.next() was removed in Python 3;
                # the next() builtin works on both Python 2 and 3.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # detach: the text encoder is not trained in this loop
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # mask padding tokens (id 0) so attention ignores them
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)
                # log one sample per scale to TensorBoard
                if gen_iterations % 20 == 0:
                    for fi in range(0, len(fake_imgs)):
                        img = fake_imgs[fi].detach().cpu()
                        writer.add_image('image/generated_sample_%d' % (fi), img[0], gen_iterations)
                #######################################################
                # (3) Update D networks (one per image scale)
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    if gen_iterations % 20 == 0:
                        writer.add_scalar('data/d_loss_%d' % (i), errD.data.item(), gen_iterations)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.data.item())
                if gen_iterations % 20 == 0:
                    writer.add_scalar('data/d_loss_total', errD_total.data.item(), gen_iterations)
                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # counters are incremented at the end of the loop body,
                # so logging above uses the current iteration index

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, gen_iterations, writer)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data.item()
                if gen_iterations % 20 == 0:
                    writer.add_scalar('data/kl_loss', kl_loss.data.item(), gen_iterations)
                    writer.add_scalar('data/g_loss_total', errG_total.data.item(), gen_iterations)
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights: avg = 0.999*avg + 0.001*p
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # NOTE(review): prints every iteration (original threshold
                # of 100 was lowered to 1 during debugging)
                if gen_iterations % 1 == 0:
                    print('[%d/%d]: ' % (gen_iterations, self.num_batches) + D_logs + '\n' + G_logs)
                # image snapshotting was disabled during debugging:
                # if gen_iterations % 10 == 0:
                    # backup_para = copy_G_params(netG)
                    # load_params(netG, avg_param_G)
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       gen_iterations,
                    #                       name='average')
                    # load_params(netG, backup_para)

                step += 1
                gen_iterations += 1

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data.item(), errG_total.data.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, image_encoder, text_encoder, epoch)

        self.save_model(netG, avg_param_G, netsD, image_encoder, text_encoder, self.max_epoch)
コード例 #18
0
ファイル: trainer.py プロジェクト: LeoXing1996/DM-GAN
    def train(self):
        """Run the adversarial training loop (DM-GAN variant).

        Alternates discriminator and generator updates per batch, keeps
        an exponential moving average (EMA) of the generator weights,
        collects per-step losses/accuracies in a dict that is forwarded
        to ``self.logger`` when one is configured, and saves periodic
        image samples and model snapshots.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)  # EMA copy of generator weights
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is reused for snapshots so samples stay comparable
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        loss_dict = {}  # scalar losses/metrics forwarded to self.logger
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # next(data_iter): iterator.next() was removed in Python 3;
                # the next() builtin works on both Python 2 and 3.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # detach: the text encoder is not trained in this loop
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # mask padding tokens (id 0) so attention ignores them
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask, cap_lens)

                #######################################################
                # (3) Update D networks (one per image scale)
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD, log, d_dict = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                           sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    D_logs += log
                    loss_dict['Real_Acc_{}'.format(i)] = d_dict['Real_Acc']
                    loss_dict['Fake_Acc_{}'.format(i)] = d_dict['Fake_Acc']
                    loss_dict['errD_{}'.format(i)] = errD.item()

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs, g_dict = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                loss_dict.update(g_dict)
                loss_dict['kl_loss'] = kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA update of generator weights: avg = 0.999*avg + 0.001*p
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print('Epoch [{}/{}] Step [{}/{}]'.format(epoch, self.max_epoch, step,
                                                              self.num_batches) + ' ' + D_logs + ' ' + G_logs)
                if self.logger:
                    self.logger.log(loss_dict)
                # save sample images with the EMA weights, then restore
                if gen_iterations % 10000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb, words_embs, mask, image_encoder,
                                          captions, cap_lens, gen_iterations)
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' % (
                epoch, self.max_epoch, errD_total.item(), errG_total.item(), end_t - start_t))
            print('-' * 89)
            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #19
0
ファイル: trainer_s2.py プロジェクト: mshaikh2/MMRL
    def train(self):

        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

        tb_dir = '../tensorboard/{0}_{1}_{2}'.format(cfg.DATASET_NAME,
                                                     cfg.CONFIG_NAME,
                                                     timestamp)
        mkdir_p(tb_dir)
        tbw = SummaryWriter(log_dir=tb_dir)  # Tensorboard logging

        text_encoder, image_encoder, netG, netsD, start_epoch, cap_model = self.build_models(
        )
        labels = Variable(torch.LongTensor(range(
            self.batch_size)))  # used for matching loss

        text_encoder.train()
        image_encoder.train()
        for k, v in image_encoder.named_children(
        ):  # set the input layer1-5 not training and no grads.
            if k in frozen_list_image_encoder:
                v.training = False
                v.requires_grad_(False)
        netG.train()
        for i in range(len(netsD)):
            netsD[i].train()
        cap_model.train()

        avg_param_G = copy_G_params(netG)
        optimizerI, optimizerT, optimizerG , optimizersD , optimizerC , lr_schedulerC \
        , lr_schedulerI , lr_schedulerT = self.define_optimizers(image_encoder
                                                                , text_encoder
                                                                , netG
                                                                , netsD
                                                                , cap_model)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))

        cap_criterion = torch.nn.CrossEntropyLoss(
        )  # add caption criterion here
        if cfg.CUDA:
            labels = labels.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            cap_criterion = cap_criterion.cuda()  # add caption criterion here
        cap_criterion.train()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):

            ##### set everything to trainable ####
            text_encoder.train()
            image_encoder.train()
            netG.train()
            cap_model.train()
            for k, v in image_encoder.named_children():
                if k in frozen_list_image_encoder:
                    v.train(False)
            for i in range(len(netsD)):
                netsD[i].train()
            ##### set everything to trainable ####

            fi_w_total_loss0 = 0
            fi_w_total_loss1 = 0
            fi_s_total_loss0 = 0
            fi_s_total_loss1 = 0
            ft_w_total_loss0 = 0
            ft_w_total_loss1 = 0
            ft_s_total_loss0 = 0
            ft_s_total_loss1 = 0
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            c_total_loss = 0

            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                print('step:{:6d}|{:3d}'.format(step, self.num_batches),
                      end='\r')
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = data_iter.next()
                # add images, image masks, captions, caption masks for catr model
                imgs, captions, cap_lens, class_ids, keys, cap_imgs, cap_img_masks, sentences, sent_masks = prepare_data(
                    data)

                ################## feedforward damsm model ##################
                image_encoder.zero_grad()  # image/text encoders zero_grad here
                text_encoder.zero_grad()

                words_features, sent_code = image_encoder(
                    cap_imgs
                )  # input catr images to image encoder, feedforward, Nx256x17x17
                #                 words_features, sent_code = image_encoder(imgs[-1]) # input image_encoder
                nef, att_sze = words_features.size(1), words_features.size(2)
                # hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(
                    captions)  #, cap_lens, hidden)

                #### damsm losses
                w_loss0, w_loss1, attn_maps = words_loss(
                    words_features, words_embs, labels, cap_lens, class_ids,
                    batch_size)
                w_total_loss0 += w_loss0.data
                w_total_loss1 += w_loss1.data
                damsm_loss = w_loss0 + w_loss1

                s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, labels,
                                             class_ids, batch_size)
                s_total_loss0 += s_loss0.data
                s_total_loss1 += s_loss1.data
                damsm_loss += s_loss0 + s_loss1

                #                 damsm_loss.backward()

                #                 words_features = words_features.detach()
                # real image real text matching loss graph cleared here
                # grad accumulated -> text_encoder
                #                  -> image_encoder
                #################################################################################

                ################## feedforward image encoder and caption model ##################
                #                 words_features, sent_code = image_encoder(cap_imgs)
                cap_model.zero_grad()  # caption model zero_grad here

                cap_preds = cap_model(
                    words_features, cap_img_masks, sentences[:, :-1],
                    sent_masks[:, :-1])  # caption model feedforward
                cap_loss = caption_loss(cap_criterion, cap_preds, sentences)
                c_total_loss += cap_loss.data
                #                 cap_loss.backward() # caption loss graph cleared,
                # grad accumulated -> cap_model -> image_encoder
                torch.nn.utils.clip_grad_norm_(cap_model.parameters(),
                                               config.clip_max_norm)
                #                 optimizerC.step() # update cap_model params
                #################################################################################

                ############ Prepare the input to Gan from the output of text_encoder ################
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #                 f_img = np.asarray(fake_imgs[-1].permute((0,2,3,1)).detach().cpu())
                #                 print('fake_imgs.size():{0},fake_imgs.min():{1},fake_imgs.max():{2}'.format(fake_imgs[-1].size()
                #                                   ,fake_imgs[-1].min()
                #                                   ,fake_imgs[-1].max()))

                #                 print('f_img.shape:{0}'.format(f_img.shape))

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    print(i)
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.data)

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.data
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # 14 -- 2800 iterations=steps for 1 epoch
                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')

                #### temporary check ####
#                 if step == 5:
#                     break
#                 print('fake_img shape:',fake_imgs[-1].size())

#                 # this is fine #### exists in GeneratorLoss
#                 fake_imgs[-1] = fake_imgs[-1].detach()
#                 ####### fake imge real text matching loss #################
#                 fi_word_features, fi_sent_code = image_encoder(fake_imgs[-1])
# #                 words_embs, sent_emb = text_encoder(captions) # to update the text

#                 fi_w_loss0, fi_w_loss1, fi_attn_maps = words_loss(fi_word_features, words_embs, labels,
#                                                  cap_lens, class_ids, batch_size)

#                 fi_w_total_loss0 += fi_w_loss0.data
#                 fi_w_total_loss1 += fi_w_loss1.data

#                 fi_damsm_loss = fi_w_loss0 + fi_w_loss1

#                 fi_s_loss0, fi_s_loss1 = sent_loss(fi_sent_code, sent_emb, labels, class_ids, batch_size)

#                 fi_s_total_loss0 += fi_s_loss0.data
#                 fi_s_total_loss1 += fi_s_loss1.data

#                 fi_damsm_loss += fi_s_loss0 + fi_s_loss1

#                 fi_damsm_loss.backward()

###### real image fake text matching loss ##############

                fake_preds = torch.argmax(cap_preds,
                                          axis=-1)  # capation predictions
                fake_captions = tokenizer.batch_decode(
                    fake_preds.tolist(),
                    skip_special_tokens=True)  # list of strings
                fake_outputs = retokenizer.batch_encode_plus(
                    fake_captions,
                    max_length=64,
                    padding='max_length',
                    add_special_tokens=False,
                    return_attention_mask=True,
                    return_token_type_ids=False,
                    truncation=True)
                fake_tokens = fake_outputs['input_ids']
                #                 fake_tkmask = fake_outputs['attention_mask']
                f_tokens = np.zeros((len(fake_tokens), 15), dtype=np.int64)
                f_cap_lens = []
                cnt = 0
                for i in fake_tokens:
                    temp = np.array([x for x in i if x != 27299 and x != 0])
                    num_words = len(temp)
                    if num_words <= 15:
                        f_tokens[cnt][:num_words] = temp
                    else:
                        ix = list(np.arange(num_words))  # 1, 2, 3,..., maxNum
                        np.random.shuffle(ix)
                        ix = ix[:15]
                        ix = np.sort(ix)
                        f_tokens[cnt] = temp[ix]
                        num_words = 15
                    f_cap_lens.append(num_words)
                    cnt += 1

                f_tokens = Variable(torch.tensor(f_tokens))
                f_cap_lens = Variable(torch.tensor(f_cap_lens))
                if cfg.CUDA:
                    f_tokens = f_tokens.cuda()
                    f_cap_lens = f_cap_lens.cuda()

                ft_words_emb, ft_sent_emb = text_encoder(
                    f_tokens)  # input text_encoder

                ft_w_loss0, ft_w_loss1, ft_attn_maps = words_loss(
                    words_features, ft_words_emb, labels, f_cap_lens,
                    class_ids, batch_size)

                ft_w_total_loss0 += ft_w_loss0.data
                ft_w_total_loss1 += ft_w_loss1.data

                ft_damsm_loss = ft_w_loss0 + ft_w_loss1

                ft_s_loss0, ft_s_loss1 = sent_loss(sent_code, ft_sent_emb,
                                                   labels, class_ids,
                                                   batch_size)

                ft_s_total_loss0 += ft_s_loss0.data
                ft_s_total_loss1 += ft_s_loss1.data

                ft_damsm_loss += ft_s_loss0 + ft_s_loss1

                #                 ft_damsm_loss.backward()

                total_multimodal_loss = damsm_loss + ft_damsm_loss + cap_loss
                total_multimodal_loss.backward()
                ## loss = 0.5*loss1 + 0.4*loss2 + ...
                ## loss.backward() -> accumulate grad value in parameters.grad

                ## loss1 = 0.5*loss1
                ## loss1.backward()

                torch.nn.utils.clip_grad_norm_(image_encoder.parameters(),
                                               cfg.TRAIN.RNN_GRAD_CLIP)

                optimizerI.step()

                torch.nn.utils.clip_grad_norm_(text_encoder.parameters(),
                                               cfg.TRAIN.RNN_GRAD_CLIP)
                optimizerT.step()

                optimizerC.step()  # update cap_model params

            lr_schedulerC.step()
            lr_schedulerI.step()
            lr_schedulerT.step()

            end_t = time.time()

            tbw.add_scalar('Loss_D', float(errD_total.item()), epoch)
            tbw.add_scalar('Loss_G', float(errG_total.item()), epoch)
            tbw.add_scalar('train_w_loss0', float(w_total_loss0.item()), epoch)
            tbw.add_scalar('train_s_loss0', float(s_total_loss0.item()), epoch)
            tbw.add_scalar('train_w_loss1', float(w_total_loss1.item()), epoch)
            tbw.add_scalar('train_s_loss1', float(s_total_loss1.item()), epoch)
            tbw.add_scalar('train_c_loss', float(c_total_loss.item()), epoch)

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.data,
                   errG_total.data, end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:

                self.save_model(netG, avg_param_G, image_encoder, text_encoder,
                                netsD, epoch, cap_model, optimizerC,
                                optimizerI, optimizerT, lr_schedulerC,
                                lr_schedulerI, lr_schedulerT)

            v_s_cur_loss, v_w_cur_loss, v_c_cur_loss = self.evaluate(
                self.dataloader_val, image_encoder, text_encoder, cap_model,
                self.batch_size)
            print(
                'v_s_cur_loss:{:.5f}, v_w_cur_loss:{:.5f}, v_c_cur_loss:{:.5f}'
                .format(v_s_cur_loss, v_w_cur_loss, v_c_cur_loss))
            tbw.add_scalar('val_w_loss', float(v_w_cur_loss), epoch)
            tbw.add_scalar('val_s_loss', float(v_s_cur_loss), epoch)
            tbw.add_scalar('val_c_loss', float(v_c_cur_loss), epoch)

        self.save_model(netG, avg_param_G, image_encoder, text_encoder, netsD,
                        self.max_epoch, cap_model, optimizerC, optimizerI,
                        optimizerT, lr_schedulerC, lr_schedulerI,
                        lr_schedulerT)
コード例 #20
0
ファイル: trainer.py プロジェクト: liangy396/AttentionalGAN
    def eval(self):
        """Run one pass over the data loader, computing (but not applying)
        the discriminator and generator losses.

        Mirrors the training loop's forward pass: encodes captions, generates
        fake images, and prints the per-step D/G loss logs. No optimizer step
        is taken, so model parameters are left unchanged.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        start_t = time.time()

        data_iter = iter(self.data_loader)
        step = 0
        while step < self.num_batches:
            ######################################################
            # (1) Prepare data and compute text embeddings
            ######################################################
            data = data_iter.next()
            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            # Mask out padding tokens (index 0) for attention; clip to the
            # embedded sequence length.
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

            #######################################################
            # (3) Compute D losses (no parameter update in eval)
            ######################################################
            errD_total = 0
            D_logs = ''
            for i in range(len(netsD)):
                netsD[i].zero_grad()
                errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                          sent_emb, real_labels, fake_labels)
                errD_total += errD
                # FIX: this line was tab-indented (TabError under Python 3)
                # and fell outside the loop body; it belongs inside the
                # per-discriminator loop.
                D_logs += 'errD%d: %.2f ' % (i, errD.item())

            #######################################################
            # (4) Compute G losses
            ######################################################
            step += 1
            gen_iterations += 1

            netG.zero_grad()
            errG_total, G_logs = \
                generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                               words_embs, sent_emb, match_labels, cap_lens, class_ids)
            kl_loss = KL_loss(mu, logvar)
            errG_total += kl_loss
            # FIX: tab indentation normalized to spaces.
            G_logs += 'kl_loss: %.2f ' % kl_loss.item()
            print(D_logs + '\n' + G_logs)

        end_t = time.time()

        # FIX: `epoch` was never defined in this method (NameError); report
        # the epoch the models were restored at instead.
        print('''[%d/%d][%d]
              Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
              % (start_epoch, self.max_epoch, self.num_batches,
                 errD_total.item(), errG_total.item(),
                 end_t - start_t))
コード例 #21
0
    def train(self):
        """Main adversarial training loop with TensorBoard logging.

        Per batch: encodes captions with a frozen text encoder, generates
        fake images at each branch resolution, updates every discriminator,
        then updates the generator (adversarial + KL loss) and maintains an
        exponential moving average (EMA) of the generator weights. Sample
        images and checkpoints are saved periodically.
        """
        writer = SummaryWriter('runs/architecture')
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            print("=================================START TRAINING========================================")
            print("++++++++++++++++++++++++++++++++++%d+++++++++++++++++++++++++++++++++++++++++++++++++++\n" % gen_iterations)
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # FIX: `data_iter.next()` is Python-2 only; use builtin next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Text encoder is frozen here: detach so no grads flow into it.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask padding tokens (index 0) for attention; clip to the
                # embedded sequence length.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D networks (one per image resolution)
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                    writer.add_scalar('data/errD%d' % i, errD.item(), gen_iterations)

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA of generator weights: avg = 0.999*avg + 0.001*current.
                # FIX: add_(Number, Tensor) is deprecated/removed in modern
                # torch; the keyword form add_(tensor, alpha=...) is required.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images using the EMA weights, then restore the live ones
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')

            end_t = time.time()
            writer.add_scalar('data/Loss_D', errD_total.item(), epoch)
            writer.add_scalar('data/Loss_G', errG_total.item(), epoch)
            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
        writer.export_scalars_to_json("./all_scalars.json")
        writer.close()
コード例 #22
0
ファイル: trainer.py プロジェクト: k-eak/e-AttnGAN
    def train(self):
        """Adversarial training loop with auxiliary attribute classification.

        Per batch: adds instance noise to the real images, updates each
        discriminator with an adversarial loss plus a down-weighted attribute
        classification loss (class/color/sleeve/gender), updates the generator
        (adversarial + KL loss), and maintains an exponential moving average
        (EMA) of the generator weights. Sample images and checkpoints are
        saved periodically.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                batch_t_begin = time.time()

                ######################################################
                # (1) Prepare training data and compute text embeddings
                ######################################################
                # FIX: `data_iter.next()` is Python-2 only; use builtin next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, color_ids, sleeve_ids, gender_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Text encoder is frozen here: detach so no grads flow into it.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask padding tokens (index 0) for attention; clip to the
                # embedded sequence length.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                #######################################################
                # (3) Update D networks (one per image resolution)
                ######################################################
                errD_total = 0
                D_logs = ''
                D_logs_cls = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    imgs[i] = gaussian_to_input(imgs[i])  # INSTANCE NOISE
                    errD, cls_D = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                     sent_emb, real_labels, fake_labels,
                                                     class_ids, color_ids, sleeve_ids, gender_ids)
                    # Adversarial loss plus classification loss weighted 1/3;
                    # backward and update parameters.
                    errD_both = errD + cls_D / 3.
                    errD_both.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    errD_total += cls_D / 3.0
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    D_logs_cls += 'clsD%d: %.2f ' % (i, cls_D.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens,
                                   class_ids, color_ids, sleeve_ids, gender_ids, imgs)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA of generator weights: avg = 0.999*avg + 0.001*current.
                # FIX: add_(Number, Tensor) is deprecated/removed in modern
                # torch; the keyword form add_(tensor, alpha=...) is required.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    batch_t_end = time.time()
                    print('| epoch {:3d} | {:5d}/{:5d} batches | batch_timer: {:5.2f} | '
                          .format(epoch, step, self.num_batches,
                                  batch_t_end - batch_t_begin,))
                    print(D_logs + '\n' + D_logs_cls + '\n' + G_logs)
                # save images using the EMA weights, then restore the live ones
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
コード例 #23
0
    def train(self):
        netsG, netsD, inception_model, classifiers, start_epoch = self.build_models(
        )
        avg_params_G = []
        for i in range(len(netsG)):
            avg_params_G.append(copy_G_params(netsG[i]))
        optimizersG, optimizersD, optimizersC = self.define_optimizers(
            netsG, netsD, classifiers)
        real_labels, fake_labels = self.prepare_labels()
        writer = SummaryWriter(self.args.run_dir)

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            print("epoch: {}/{}".format(epoch, self.max_epoch))
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                print("step:{}/{} {:.2f}%".format(
                    step, self.num_batches, step / self.num_batches * 100))
                """
                if(step%self.display_interval==0):
                    print("step:{}/{} {:.2f}%".format(step, self.num_batches, step/self.num_batches*100))
                """
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = data_iter.next()
                real_imgs, atts, image_atts, class_ids, keys = prepare_data(
                    data)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)

                print("before netG")
                fake_imgs = []
                C_losses = None
                if not self.args.kl_loss:
                    if cfg.TREE.BRANCH_NUM > 0:
                        fake_img1, h_code1 = nn.parallel.data_parallel(
                            netsG[0], (noise, atts, image_atts), self.gpus)
                        fake_imgs.append(fake_img1)
                        if self.args.split == 'train':  ##for train:
                            att_embeddings, C_losses = classifier_loss(
                                classifiers, inception_model, real_imgs[0],
                                image_atts, C_losses)
                            _, C_losses = classifier_loss(
                                classifiers, inception_model, fake_img1,
                                image_atts, C_losses)
                        else:
                            att_embeddings, _ = classifier_loss(
                                classifiers, inception_model, fake_img1,
                                image_atts)

                    if cfg.TREE.BRANCH_NUM > 1:
                        fake_img2, h_code2 = nn.parallel.data_parallel(
                            netsG[1], (h_code1, att_embeddings), self.gpus)
                        fake_imgs.append(fake_img2)
                        if self.args.split == 'train':
                            att_embeddings, C_losses = classifier_loss(
                                classifiers, inception_model, real_imgs[1],
                                image_atts, C_losses)
                            _, C_losses = classifier_loss(
                                classifiers, inception_model, fake_img1,
                                image_atts, C_losses)
                        else:
                            att_embeddings, _ = classifier_loss(
                                classifiers, inception_model, fake_img1,
                                image_atts)

                    if cfg.TREE.BRANCH_NUM > 2:
                        fake_img3 = nn.parallel.data_parallel(
                            netsG[2], (h_code2, att_embeddings), self.gpus)
                        fake_imgs.append(fake_img3)
                print("end netG")
                """
                if not self.args.kl_loss:
                    fake_imgs, C_losses = nn.parallel.data_parallel( netG, (noise, atts, image_atts,
                                                 inception_model, classifiers, imgs), self.gpus)
                else:
                    fake_imgs, mu, logvar = netG(noise, atts, image_atts) ## model内の次元が合っていない可能性。
                """

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                errD_dic = {}
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], real_imgs[i],
                                              fake_imgs[i], atts, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
                    errD_dic['D_%d' % i] = errD.item()

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                for i in range(len(netsG)):
                    netsG[i].zero_grad()
                print("before c backward")
                errC_total = 0
                C_logs = ''
                for i in range(len(classifiers)):
                    classifiers[i].zero_grad()
                    C_losses[i].backward(retain_graph=True)
                    optimizersC[i].step()
                    errC_total += C_losses[i]
                C_logs += 'errC_total: %.2f ' % (errC_total.item())
                print("end c backward")
                """
                for i,param in enumerate(netsG[0].parameters()):
                    if i==0:
                        print(param.grad)
                """

                ##TODO netGにgradientが溜まっているかどうかを確認せよ。

                errG_total = 0
                errG_total, G_logs, errG_dic = \
                    generator_loss(netsD, fake_imgs, real_labels, atts, errG_total)
                if self.args.kl_loss:
                    kl_loss = KL_loss(mu, logvar)
                    errG_total += kl_loss
                    G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                    writer.add_scalar('kl_loss', kl_loss.item(),
                                      epoch * self.num_batches + step)

                # backward and update parameters
                errG_total.backward()
                for i in range(len(optimizersG)):
                    optimizersG[i].step()
                for i in range(len(optimizersC)):
                    optimizersC[i].step()

                errD_dic.update(errG_dic)
                writer.add_scalars('training_losses', errD_dic,
                                   epoch * self.num_batches + step)
                """
                self.save_img_results(netsG, fixed_noise, atts, image_atts, inception_model, 
                             classifiers, real_imgs, epoch, name='average') ##for debug
                """

                for i in range(len(netsG)):
                    for p, avg_p in zip(netsG[i].parameters(),
                                        avg_params_G[i]):
                        avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs + '\n' + C_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    backup_paras = []
                    for i in range(len(netsG)):
                        backup_para = copy_G_params(netsG[i])
                        backup_paras.append(backup_para)
                        load_params(netsG[i], avg_params_G[i])
                    self.save_img_results(netsG,
                                          fixed_noise,
                                          atts,
                                          image_atts,
                                          inception_model,
                                          classifiers,
                                          imgs,
                                          epoch,
                                          name='average')
                    for i in raneg(len(netsG)):
                        load_params(netsG[i], backup_paras[i])
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Loss_C: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), errC_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netsG, avg_params_G, netsD, classifiers, epoch)

        self.save_model(netsG, avg_params_G, netsD, classifiers,
                        self.max_epoch)
コード例 #24
0
    def train(self):
        """Run the adversarial training loop.

        Per batch the loop:
          (1) encodes the captions with the (frozen) text encoder,
          (2) generates fake images with ``netG`` (data-parallel over
              ``self.gpus``),
          (3) updates each discriminator in ``netsD``,
          (4) updates ``netG`` with the generator + KL losses and maintains
              an exponential moving average (EMA) of its parameters.

        Sample images are saved every 1000 generator iterations (using the
        EMA weights) and checkpoints every ``cfg.TRAIN.SNAPSHOT_INTERVAL``
        epochs, plus once after the final epoch.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # FIX: `data_iter.next()` is Python 2 only; Python 3
                # iterators expose __next__ via the builtin next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot = prepare_data(data)
                transf_matrices = transformation_matrices[0]
                transf_matrices_inv = transformation_matrices[1]

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # The text encoder is not trained here: detach so G/D
                # gradients do not flow back into it.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                # Mask out padding positions (caption token id 0), trimmed
                # to the sequence length the encoder actually produced.
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv, label_one_hot)
                fake_imgs, _, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    # Only the first discriminator receives the object
                    # layout (one-hot labels + transformation matrices).
                    if i == 0:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels, self.gpus,
                                                  local_labels=label_one_hot, transf_matrices=transf_matrices,
                                                  transf_matrices_inv=transf_matrices_inv)
                    else:
                        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                                  sent_emb, real_labels, fake_labels, self.gpus)

                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids, self.gpus,
                                   local_labels=label_one_hot, transf_matrices=transf_matrices,
                                   transf_matrices_inv=transf_matrices_inv)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA of generator weights: avg = 0.999 * avg + 0.001 * p.
                # FIX: the positional `add_(alpha, tensor)` overload is
                # deprecated/removed in modern PyTorch; pass alpha by keyword.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # save images
                if gen_iterations % 1000 == 0:
                    print(D_logs + '\n' + G_logs)

                    # Temporarily swap in the EMA weights for sampling,
                    # then restore the live training weights.
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch,  transf_matrices_inv,
                                          label_one_hot, name='average')
                    load_params(netG, backup_para)
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch)

        # FIX: save the final model under self.max_epoch rather than the
        # loop's last `epoch` value (max_epoch - 1), which could collide
        # with the last in-loop snapshot.
        self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, self.max_epoch)
コード例 #25
0
ファイル: trainer.py プロジェクト: smallflyingpig/AttnGAN
    def train(self):
        """Run the adversarial training loop (AttnGAN variant with
        TensorBoard logging and optional audio captions).

        Per batch the loop:
          (1) encodes captions with the (frozen) text encoder, building a
              padding mask either from fixed-length audio caption lengths
              or from zero token ids,
          (2) generates fake images with ``netG``,
          (3) updates each discriminator in ``netsD``,
          (4) updates ``netG`` with generator + KL losses, maintains an
              exponential moving average (EMA) of its parameters, and
              logs per-discriminator and per-component generator losses
              to ``self.writer``.

        Sample images (from the EMA weights) are saved every 1000
        generator iterations; checkpoints every
        ``cfg.TRAIN.SNAPSHOT_INTERVAL`` epochs and once at the end.
        """
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models(
        )
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = (torch.FloatTensor(batch_size, nz).requires_grad_())
        fixed_noise = (torch.FloatTensor(batch_size,
                                         nz).normal_(0, 1).requires_grad_())
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                # FIX: `data_iter.next()` is Python 2 only; Python 3
                # iterators expose __next__ via the builtin next().
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens)
                # The text encoder is not trained here: detach so G/D
                # gradients do not flow back into it.
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                if self.audio_flag:
                    # Audio captions are padded to a fixed length of 32;
                    # mark positions past each caption length as padding.
                    # NOTE(review): ByteTensor masks are deprecated in
                    # modern PyTorch in favor of bool — confirm netG's
                    # expectation before switching the dtype.
                    mask = torch.ByteTensor([[0] * cap_len.item() + [1] *
                                             (32 - cap_len.item())
                                             for cap_len in cap_lens])
                    mask = mask.to(words_embs.device)
                else:
                    # Text captions: padding positions have token id 0.
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                errD_seq = []
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_seq.append(errD)
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs, lossG_seq = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # EMA of generator weights: avg = 0.999 * avg + 0.001 * p.
                # FIX: the positional `add_(alpha, tensor)` overload is
                # deprecated/removed in modern PyTorch; pass alpha by keyword.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # TensorBoard: total D loss, per-discriminator D losses,
                # and the generator loss broken into its components.
                self.writer.add_scalars(main_tag="loss_d",
                                        tag_scalar_dict={"loss_d": errD_total},
                                        global_step=gen_iterations)
                for idx in range(len(netsD)):
                    self.writer.add_scalars(main_tag="loss_d",
                                            tag_scalar_dict={
                                                "loss_d_{:d}".format(idx):
                                                errD_seq[idx]
                                            },
                                            global_step=gen_iterations)
                self.writer.add_scalars(main_tag="loss_g",
                                        tag_scalar_dict={
                                            "loss_g": errG_total,
                                            "loss_g_gan": lossG_seq[0],
                                            "loss_g_w": lossG_seq[1],
                                            "loss_g_s": lossG_seq[2],
                                            "kl_loss": kl_loss
                                        },
                                        global_step=gen_iterations)
                if gen_iterations % 100 == 0:
                    self.logger.info(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    # Temporarily swap in the EMA weights for sampling,
                    # then restore the live training weights.
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    self.logger.info("save image result...")
            end_t = time.time()

            self.logger.info('''[{}/{}][{}]
                Loss_D: {:.2f} Loss_G: {:.2f} Time: {:.2f}'''.format(
                epoch, self.max_epoch, self.num_batches, errD_total.item(),
                errG_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)