Example #1
    def evaluate(self):
        '''Model evaluation'''

        self.log.add("| Calculating validation error")
        self.image_encoder.eval()
        self.text_encoder.eval()
        s_total_loss = 0
        w_total_loss = 0
        for step, data in enumerate(self.data_loader_val, 0):
            real_imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            words_features, sent_code = self.image_encoder(real_imgs[-1])
            # nef = words_features.size(1)
            # words_features = words_features.view(batch_size, nef, -1)

            hidden = self.text_encoder.init_hidden(self.batch_size)
            words_emb, sent_emb = self.text_encoder(captions, cap_lens, hidden)

            w_loss0, w_loss1, attn = words_loss(words_features, words_emb, self.labels,
                                                cap_lens, class_ids, self.batch_size)
            w_total_loss += (w_loss0 + w_loss1).data

            s_loss0, s_loss1 = sent_loss(sent_code, sent_emb, self.labels, class_ids, self.batch_size)
            s_total_loss += (s_loss0 + s_loss1).data

            # Evaluate on at most 50 validation batches.
            if step == 50:
                break

        # Average the accumulated losses over the batches seen.
        s_cur_loss = s_total_loss.item() / step
        w_cur_loss = w_total_loss.item() / step

        return s_cur_loss, w_cur_loss
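
This evaluation leans on the DAMSM matching losses. `sent_loss` scores every image embedding against every sentence embedding in the batch with a scaled cosine similarity, then applies cross-entropy in both directions so that matched pairs score highest; `words_loss` does the analogous thing between individual words and image regions via attention. A minimal sketch of the sentence-level idea (the simplified signature and the `gamma3` smoothing factor are assumptions here; the real function additionally masks same-class negatives using `class_ids`):

import torch
import torch.nn.functional as F

def sent_loss_sketch(cnn_code, rnn_code, labels, gamma3=10.0, eps=1e-8):
    # cnn_code, rnn_code: (batch, nef) image / sentence embeddings.
    # labels: LongTensor [0, 1, ..., batch - 1] marking the matched pairs.
    cnn_n = F.normalize(cnn_code, dim=1, eps=eps)
    rnn_n = F.normalize(rnn_code, dim=1, eps=eps)
    scores = gamma3 * cnn_n @ rnn_n.t()          # (batch, batch) cosine scores
    loss0 = F.cross_entropy(scores, labels)      # image -> text direction
    loss1 = F.cross_entropy(scores.t(), labels)  # text -> image direction
    return loss0, loss1

With this formulation, `self.labels` is simply the batch indices, e.g. `torch.arange(batch_size, dtype=torch.long)`.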
Example #2
    def save_img_results(self,
                         netG,
                         noise,
                         sent_emb,
                         words_embs,
                         mask,
                         image_encoder,
                         captions,
                         cap_lens,
                         gen_iterations,
                         name='current'):
        '''Save generator results'''

        self.log.add("| Saving {} result images ... ".format(name))

        fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs,
                                               mask)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = build_super_images(img,
                                            captions,
                                            self.ixtoword,
                                            attn_maps,
                                            att_sze,
                                            lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png' % (self.image_dir, name,
                                                  gen_iterations, i)
                im.save(fullpath)

        # Word-attention visualization on the final (highest-resolution) stage.
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(), None, cap_lens, None,
                                    self.batch_size)
        img_set, _ = build_super_images(fake_imgs[i].detach().cpu(), captions,
                                        self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png' % (self.image_dir, name,
                                           gen_iterations)
            im.save(fullpath)
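
`build_super_images` stitches the images, the decoded captions, and the per-word attention maps into one grid for visual inspection. As a rough stand-in for its core step (not the actual helper), the sketch below upsamples a single word's attention map to image resolution and alpha-blends it over the image; `attention_maps[i]` above appears to be (batch, seq_len, att_sze, att_sze), so one sample's slice is the `attn` expected here:

import torch
import torch.nn.functional as F

def overlay_word_attention(img, attn, word_idx, alpha=0.5):
    # img:  (3, H, W) float tensor in [0, 1].
    # attn: (seq_len, att_sze, att_sze) attention maps for one sample.
    a = attn[word_idx][None, None]                     # (1, 1, s, s)
    a = F.interpolate(a, size=img.shape[-2:],
                      mode='bilinear', align_corners=False)[0, 0]
    a = (a - a.min()) / (a.max() - a.min() + 1e-8)     # rescale to [0, 1]
    return (1 - alpha) * img + alpha * a               # broadcast over channels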
Example #3
    def train_step(self, epoch):

        self.image_encoder.train()
        self.text_encoder.train()
        s_total_loss0 = 0
        s_total_loss1 = 0
        s_loss_step = 0.
        w_total_loss0 = 0
        w_total_loss1 = 0
        w_loss_step = 0.

        batch_num = len(self.data_loader)
        count = (epoch + 1) * batch_num

        start_time = time.time()
        for step, data in enumerate(self.data_loader, 0):

            # self.save_raw_data(step, data)  # DEBUG

            self.text_encoder.zero_grad()
            self.image_encoder.zero_grad()

            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            # words_features: batch_size x nef x 17 x 17
            # sent_code: batch_size x nef
            words_features, sent_code = self.image_encoder(imgs[-1])
            # --> batch_size x nef x 17*17
            nef, att_sze = words_features.size(1), words_features.size(2)
            # words_features = words_features.view(batch_size, nef, -1)

            hidden = self.text_encoder.init_hidden(self.batch_size)
            # words_emb: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_emb, sent_emb = self.text_encoder(captions, cap_lens, hidden)

            w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, self.labels,
                                                     cap_lens, class_ids, self.batch_size)
            w_total_loss0 += w_loss0.data
            w_total_loss1 += w_loss1.data
            loss = w_loss0 + w_loss1

            s_loss0, s_loss1 = \
                sent_loss(sent_code, sent_emb, self.labels, class_ids, self.batch_size)
            loss += s_loss0 + s_loss1
            s_total_loss0 += s_loss0.data
            s_total_loss1 += s_loss1.data
            #
            w_loss_step += w_loss0.data + w_loss1.data
            s_loss_step += s_loss0.data + s_loss1.data
            loss.backward()
            #
            # `clip_grad_norm_` helps prevent the exploding gradient
            # problem in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(self.text_encoder.parameters(), cfg.TRAIN.RNN_GRAD_CLIP)
            self.optimizer.step()

            if step % self.update_interval == 0:
                count = epoch * batch_num + step

                s_cur_loss0 = s_total_loss0.item() / self.update_interval
                s_cur_loss1 = s_total_loss1.item() / self.update_interval

                w_cur_loss0 = w_total_loss0.item() / self.update_interval
                w_cur_loss1 = w_total_loss1.item() / self.update_interval

                elapsed = time.time() - start_time
                self.log.add('| Epoch {:3d} | bt {:3d}/{:3d} | ms/bt {:5.2f} | '
                             'S_loss {:2.4f} {:2.4f} | W_loss {:2.4f} {:2.4f} | Time {:5.2f}s'
                             .format(epoch, step, len(self.data_loader),
                                     elapsed * 1000. / self.update_interval,
                                     s_cur_loss0, s_cur_loss1, w_cur_loss0, w_cur_loss1, elapsed))

                s_total_loss0 = 0
                s_total_loss1 = 0
                w_total_loss0 = 0
                w_total_loss1 = 0
                start_time = time.time()

                # Attention Maps
                img_set, _ = build_super_images(imgs[-1].cpu(), captions,
                                                self.ixtoword, attn_maps, att_sze)
                if img_set is not None:
                    im = Image.fromarray(img_set)
                    fullpath = '%s/attention_maps_%d_%d.png' % (self.image_dir, epoch, step)
                    im.save(fullpath)

        s_loss_step /= batch_num
        w_loss_step /= batch_num

        return count, s_loss_step, w_loss_step
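
The `clip_grad_norm_` call in `train_step` rescales all of the text encoder's gradients in place whenever their joint L2 norm exceeds `cfg.TRAIN.RNN_GRAD_CLIP`, the standard guard against exploding gradients in recurrent nets. A tiny standalone demonstration of its semantics (the toy LSTM and the `max_norm` value are arbitrary):

import torch

# Toy recurrent module standing in for self.text_encoder.
rnn = torch.nn.LSTM(input_size=8, hidden_size=8)
loss = rnn(torch.randn(5, 2, 8))[0].pow(2).sum()
loss.backward()

# clip_grad_norm_ returns the total norm measured *before* clipping and
# rescales every .grad in place so their joint L2 norm is at most max_norm.
pre = torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=0.25)
post = torch.sqrt(sum(p.grad.pow(2).sum() for p in rnn.parameters()))
print('norm before: %.3f, after: %.3f' % (pre.item(), post.item()))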