Example #1
    def save_img_results(self,
                         real_img,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         mask,
                         image_encoder,
                         captions,
                         cap_lens,
                         gen_iterations,
                         transf_matrices_inv,
                         label_one_hot,
                         name='current',
                         num_visualize=8):

        qa_nums = (cap_lens > 0).sum(1)
        real_captions = captions
        # Swap in dummy captions for the image grids; keep the real ones
        # for the HTML attention dump below.
        captions, _ = make_fake_captions(qa_nums)

        # Save images
        # fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
        inputs = (noise, sent_emb, words_embs, mask, transf_matrices_inv,
                  label_one_hot)
        fake_imgs, attention_maps, _, _ = nn.parallel.data_parallel(
            netG, inputs, self.gpus)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                   attn_maps, att_sze, lr_imgs=lr_img,
                                   nvis=num_visualize)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png'\
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        for i in range(cfg.TREE.BRANCH_NUM):
            save_pure_img_results(real_img[i].detach().cpu(),
                                  fake_imgs[i].detach().cpu(),
                                  gen_iterations,
                                  self.image_dir,
                                  token='level%d' % i)

        # Visualize DAMSM word-to-region attention on the highest-resolution
        # output (the last branch).
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(), None, qa_nums, None,
                                    self.batch_size)
        img_set, _ = build_super_images(fake_imgs[i].detach().cpu(),
                                        captions,
                                        self.ixtoword,
                                        att_maps,
                                        att_sze,
                                        nvis=num_visualize)
        # FIXME currently the `render_attn_to_html` supports only the last level.
        # please implement multiple level rendering.
        html_doc = render_attn_to_html(
            [real_img[i].detach().cpu(),
             fake_imgs[i].detach().cpu()],
            real_captions,
            self.ixtoword,
            att_maps,
            att_sze,
            None,
            info=['Real Images', 'Fake Images'])
        with open('%s/damsm_attn_%d.html' % (self.image_dir, gen_iterations),
                  'w') as html_f:
            html_f.write(str(html_doc))

        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png'\
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)
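
Note: `save_pure_img_results` is imported from elsewhere in the project and is not shown in this example. Below is a minimal sketch of what such a helper could look like, built on `torchvision.utils.save_image`; the argument names and file layout here are assumptions for illustration, not the original implementation.

import os

from torchvision.utils import save_image


def save_pure_img_results(real_imgs, fake_imgs, gen_iterations, image_dir,
                          token='level0'):
    # Hypothetical helper: dump grids of the real and fake batches for one
    # resolution level. Generator outputs are assumed to lie in [-1, 1];
    # normalize=True rescales each grid to [0, 1] before saving.
    for prefix, imgs in (('real', real_imgs), ('fake', fake_imgs)):
        fullpath = os.path.join(
            image_dir, '%s_%s_%d.png' % (prefix, token, gen_iterations))
        save_image(imgs, fullpath, nrow=8, normalize=True)
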
Example #2
def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    global global_step
    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        global_step += 1
        # print('step', step)
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        # imgs: b x 3 x nbasesize x nbasesize
        imgs, captions, cap_lens, \
            class_ids, _, _, _, keys, _ = prepare_data(data)
        
        class_ids = None  # FIXME: is it OK to discard class_ids here?

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        # --> batch_size x nef x 17*17
        nef, att_sze = words_features.size(1), words_features.size(2)

        # num_caps = (cap_lens > 0).sum(1)
        per_qa_embs, avg_qa_embs, num_caps = Level1RNNEncodeMagic(
            captions, cap_lens, rnn_model)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, per_qa_embs, labels,
                                                 num_caps, class_ids,
                                                 batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, avg_qa_embs, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        #
        loss.backward()
        #
        # `clip_grad_norm_` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL
            s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL

            w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL
            w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / UPDATE_INTERVAL,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()
            # Attention maps

            def make_fake_captions(num_caps):
                # Dummy captions: one 'DUMMY' token per real QA slot, so the
                # visualization code has a label for every row. All tensors
                # stay on the CPU so the boolean mask and `caps` live on the
                # same device.
                caps = torch.zeros(batch_size, cfg.TEXT.MAX_QA_NUM,
                                   dtype=torch.int64)
                ref = torch.arange(0, cfg.TEXT.MAX_QA_NUM).view(1, -1)\
                    .repeat(batch_size, 1)
                targ = num_caps.view(-1, 1).cpu().repeat(1, cfg.TEXT.MAX_QA_NUM)
                caps[ref < targ] = 1
                return caps, {1: 'DUMMY'}

            _captions, _ixtoword = make_fake_captions(num_caps)

            html_doc = render_attn_to_html(imgs[-1].cpu(), captions,
                                           ixtoword, attn_maps, att_sze)

            with open('%s/attn_step%d.html' % (image_dir, global_step), 'w') as html_f:
                html_f.write(str(html_doc))
                
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), _captions,
                                   _ixtoword, attn_maps, att_sze)

            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, global_step)
                im.save(fullpath)
    return count
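
For context, here is a rough sketch of how this `train` function might be driven from an outer epoch loop. The dataset, the two encoder models, and the `cfg` fields are stand-ins for whatever the surrounding project defines; treat this as wiring, not as the project's actual training script.

import torch

# Assumed setup: `dataset`, `image_encoder` (CNN), `text_encoder` (RNN), and
# `cfg` come from the surrounding project; only the call wiring is shown.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, drop_last=True)
# Matching labels 0..batch_size-1, as DAMSM-style sent/words losses
# typically expect.
labels = torch.arange(cfg.TRAIN.BATCH_SIZE).cuda()
optimizer = torch.optim.Adam(
    list(text_encoder.parameters()) + list(image_encoder.parameters()),
    lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999))

for epoch in range(cfg.TRAIN.MAX_EPOCH):
    count = train(dataloader, image_encoder, text_encoder,
                  cfg.TRAIN.BATCH_SIZE, labels, optimizer, epoch,
                  dataset.ixtoword, 'output/images')
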