def eval_coco(model, loader, run_name, log_dir, config):
    """Generate a caption for every image in ``loader`` and score the
    results against the COCO reference annotations.

    Args:
        model: captioning model exposing ``sample`` and ``sample_beam``.
        loader: DataLoader yielding ``(image, image_id)`` batches.
        run_name: identifier used when writing result files.
        log_dir: directory for evaluation artifacts.
        config: options object providing ``beam_size``, ``eval_greedy``
            and ``coco_annotation_file``.

    Returns:
        Whatever ``language_scores`` returns for the predicted captions.
    """
    print("Computing coco-caption scores")

    model.eval()
    max_len = loader.dataset.max_c_len + 1
    vocab = loader.dataset.c_i2w
    predictions = []

    with torch.no_grad():
        for image, image_id in loader:
            batch = image.to(device)

            if config.beam_size >= 2:
                # Beam search. Lengths are recovered from the non-zero
                # logprob positions, minus 1 to drop the <eos> token that
                # beam search includes in its output.
                out = model.sample_beam(batch,
                                        beam_size=config.beam_size,
                                        max_seq_len=max_len)
                tokens = to_np(out.best_caption)
                lengths = to_np(torch.sum(out.best_logprobs != 0, dim=1) - 1)
            else:
                # Plain (greedy or multinomial) sampling; lengths come
                # from the returned mask.
                out = model.sample(batch,
                                   greedy=config.eval_greedy,
                                   max_seq_len=max_len)
                tokens = to_np(out.caption)
                lengths = to_np(torch.sum(out.mask, dim=1))

            for j in range(image.size(0)):
                words = (vocab[t] for t in tokens[j][:lengths[j]])
                predictions.append({
                    'image_id': image_id[j].item(),
                    'caption': ' '.join(words).strip()
                })

    # Restore training mode before handing the predictions to the scorer.
    model.train()
    return language_scores(predictions,
                           run_name,
                           log_dir,
                           annFile=config.coco_annotation_file)
Пример #2
0
def build_cap_ref_dicts(pred, pred_len, targ, targ_len, eos_symbol, c_i2w):
    """Convert prediction and reference caption tensors into the
    ``gts`` / ``res`` structures expected by the coco-caption scorers.

    Args:
        pred, pred_len: predicted token tensors and their lengths.
        targ, targ_len: reference token tensors and their lengths
            (one list of references per example).
        eos_symbol: end-of-sentence token id for string conversion.
        c_i2w: index-to-word vocabulary mapping.

    Returns:
        gts: dict mapping example index -> list of reference strings.
        res: list of ``{'image_id', 'caption'}`` dicts, one per example,
            each caption wrapped in a single-element list.
    """
    pred, pred_len, ref, ref_len = (to_np(t).tolist()
                                    for t in (pred, pred_len, targ, targ_len))

    # One list of decoded reference strings per example, keyed by position.
    gts = OrderedDict()
    for idx, (example_refs, example_lens) in enumerate(zip(ref, ref_len)):
        gts[idx] = [
            symbol_to_string(r, length, eos_symbol, c_i2w)
            for r, length in zip(example_refs, example_lens)
        ]

    res = [{
        'image_id': idx,
        'caption': [symbol_to_string(p, length, eos_symbol, c_i2w)]
    } for idx, (p, length) in enumerate(zip(pred, pred_len))]

    # Plain dict with the same 0..n-1 insertion order as the OrderedDict.
    gts = dict(gts)
    return gts, res
Пример #3
0
    def collect_batch(self,
                      index,
                      captions,
                      masks,
                      rewards,
                      ans=None,
                      qs=None,
                      qlens=None,
                      topk=None):
        """Convert one batch of tensors to numpy and append it to
        ``self.data_dump`` for later offline analysis.

        The question-related fields (``ans``/``qs``/``qlens``/``topk``)
        are only recorded when ``self.model`` names the question-asking
        trainer; otherwise empty lists are stored in their place.
        NOTE(review): ``self.model`` is compared against a string here --
        presumably it holds the trainer's class name; confirm at call site.
        """
        index = to_np(index)
        captions = [to_np(c) for c in captions]
        # Caption lengths are recovered by summing each mask along dim 1.
        lens = [to_np(torch.sum(m, dim=1)) for m in masks]

        if self.model == "QuestionAskingTrainer":
            ans = [to_np(a) for a in ans]
            qs = [to_np(q) for q in qs]
            qlens = [to_np(q) for q in qlens]
            topk = [to_np(t) for t in topk]
        else:
            ans, qs, qlens, topk = [], [], [], []

        self.data_dump.append(
            [index, captions, lens, rewards, ans, qs, qlens, topk])
Пример #4
0
    def evaluate_with_questions(self):
        """Validate the decision maker on the full caption-ask-correct loop.

        Per batch: (1) caption greedily, (2) let the decision maker pick
        the single best time step to ask a question, (3) query the teacher
        for an answer, (4) build two corrected captions -- a full rollout
        from the answer and a single-word replacement -- and keep whichever
        earns the higher mixed reward. Token accuracy and COCO language
        scores are accumulated for both the base and the corrected caption.

        Returns:
            acc: per-caption-set token accuracy [base, corrected].
            eval_scores: ``language_scores`` results for both caption sets.
        """

        self.std_logger.info("Validating decision maker")

        # Only the decision maker is switched to eval mode; the captioner
        # and question generator stay in train mode (see the disabled
        # options below) -- presumably to keep their sampling behavior
        # identical to collection. TODO confirm this is intentional.
        self.dmaker.eval()
        self.captioner.train()
        self.qgen.train()
        # if self.opt.cap_eval:
        #     self.captioner.eval()
        # else:
        #     self.captioner.train()
        #
        # if self.opt.quegen_eval:
        #     self.question_generator.eval()
        # else:
        #     self.question_generator.train()

        c_i2w = self.val_loader.dataset.c_i2w

        # correct/caption_predictions hold two slots: [0] base caption,
        # [1] best corrected caption.
        correct, pos_correct, eval_scores = [0.0, 0.0], [0.0, 0.0], []
        caption_predictions = [[], []]

        with torch.no_grad():

            for step, sample in enumerate(self.val_loader):

                image, refs, target, caption_len = [
                    x.to(device)
                    for x in [sample[0], sample[4], sample[2], sample[3]]
                ]
                ref_lens, img_path, index, img_id = sample[5], sample[
                    7], sample[8], sample[9]

                batch_size = image.size(0)

                # Per-batch accumulators: captions, caption masks, POS
                # probs, pos-prob logs, questions, question logprobs,
                # question masks, answer probs, caption probs, attentions.
                caps, cmasks, poses, pps, qs, qlps, qmasks, aps, cps, atts = [], [], [], [], [], [],[], [], [], []

                # 1. Caption completely
                self.set_seed(self.opt.seed)

                r = self.captioner.sample(
                    image,
                    greedy=True,
                    max_seq_len=self.opt.c_max_sentence_len + 1)

                caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
                    = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg

                cap_len = cap_mask.long().sum(dim=1)
                caps.append(caption)
                cmasks.append(cap_mask)
                poses.append(pos_probs)

                caption = self.pad_caption(caption, cap_len)

                # get the hidden state context
                source = torch.cat([
                    self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]
                ],
                                   dim=1)

                r = self.fixed_caption_encoder(image,
                                               source,
                                               gt_pos=None,
                                               ss=False)
                h = r.hidden

                # 2. Identify the best time to ask a question, excluding ended sentences
                logit, valid_pos_mask = self.dmaker(
                    h, attended_img, caption, cap_len, pos_probs, topk_words,
                    self.captioner.caption_embedding.weight.data)
                masked_prob = masked_softmax(
                    logit,
                    cap_mask,
                    valid_pos_mask,
                    max_len=self.opt.c_max_sentence_len)
                dm_prob, ask_idx, ask_flag, ask_mask = self.sample_decision(
                    masked_prob, cap_mask, greedy=True)

                # 3. Ask the teacher a question and get the answer
                ans, ans_mask, r = self.ask_question(image,
                                                     caption,
                                                     refs,
                                                     pos_probs,
                                                     h,
                                                     att,
                                                     ask_idx,
                                                     q_greedy=True)

                # logging
                # NOTE(review): each quantity is appended twice so these
                # lists mirror the two-slot (sampled/greedy) layout used in
                # do_iteration; here both slots hold the same greedy values
                # -- confirm the duplication is intentional.
                cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
                cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
                pps.append(r.pos_prob[0])
                pps.append(r.pos_prob[0])
                atts.append(r.att[0])
                atts.append(r.att[0])
                qlps.append(r.q_logprob.unsqueeze(1))
                qlps.append(r.q_logprob.unsqueeze(1))
                qmasks.append(r.q_mask)
                qmasks.append(r.q_mask)
                qs.append(r.question)
                qs.append(r.question)
                aps.append(r.ans_prob)
                aps.append(r.ans_prob)

                # 4. Compute new captions based on teacher's answer
                # rollout caption
                r = self.caption_with_teacher_answer(image,
                                                     ask_mask,
                                                     ans_mask,
                                                     greedy=True)

                poses.append(r.pos_prob)
                rollout = r.caption
                rollout_mask = r.cap_mask

                # replace caption
                replace = replace_word_in_caption(caps[0], ans, ask_idx,
                                                  ask_flag)

                # Score all three variants with the mixed language reward;
                # replace keeps the base caption's mask/length.
                base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0],
                                                           dim=1), refs,
                                        ref_lens, self.scorers, self.c_i2w)
                rollout_rwd = mixed_reward(rollout,
                                           torch.sum(rollout_mask,
                                                     dim=1), refs, ref_lens,
                                           self.scorers, self.c_i2w)
                replace_rwd = mixed_reward(replace, torch.sum(cmasks[0],
                                                              dim=1), refs,
                                           ref_lens, self.scorers, self.c_i2w)

                caps.append(rollout)
                caps.append(replace)
                cmasks.append(rollout_mask)
                cmasks.append(cmasks[0])

                stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
                    replace_rwd, rollout_rwd, base_rwd)

                # Keep, per example, whichever of replace/rollout scored
                # higher; caps/cmasks collapse back to [base, best].
                best_cap, best_cap_mask = choose_better_caption(
                    replace_rwd, replace, cmasks[0], rollout_rwd, rollout,
                    rollout_mask)

                caps = [caps[0], best_cap]
                cmasks = [cmasks[0], best_cap_mask]

                # Collect captions for coco evaluation
                img_id = util.to_np(img_id)

                for i in range(len(caps)):
                    words, lens, = util.to_np(caps[i]), util.to_np(
                        cmasks[i].sum(dim=1))

                    for j in range(image.size(0)):
                        inds = words[j][:lens[j]]
                        caption = ""
                        for k, ind in enumerate(inds):
                            if k > 0:
                                caption = caption + ' '
                            caption = caption + c_i2w[ind]

                        pred = {'image_id': img_id[j], 'caption': caption}
                        caption_predictions[i].append(pred)

                # Token-level accuracy for each caption set, normalized by
                # the ground-truth caption length.
                for i in range(len(correct)):
                    predictions = caps[i] * cmasks[i].long()
                    correct[i] += (
                        (target == predictions).float() /
                        caption_len.float().unsqueeze(1)).sum().item()

                # Logging
                if step % self.opt.val_print_every == 0:

                    c_i2w = self.val_loader.dataset.c_i2w
                    p_i2w = self.val_loader.dataset.p_i2w
                    q_i2w = self.val_loader.dataset.q_i2w
                    a_i2w = self.val_loader.dataset.a_i2w

                    caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len, \
                    flag, raw_idx, refs, ref_lens = \
                        [util.to_np(x) for x in [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
                                                 rollout_mask.long().sum(dim=1), masked_prob, qs[0], qlps[0],
                                                 qmasks[0].long().sum(dim=1), ask_flag.long(),
                                                 ask_idx, refs, ref_lens]]
                    pos_probs, ans_probs, cap_probs = pps[0], aps[0], cps[0]

                    for i in range(image.size(0)):
                        top_pos = torch.topk(pos_probs, 3)[1]
                        top_ans = torch.topk(ans_probs[i], 3)[1]
                        top_cap = torch.topk(cap_probs[i], 5)[1]

                        # Word at the asked position ('Nan' if the index
                        # falls past the caption).
                        word = c_i2w[caption[i][
                            raw_idx[i]]] if raw_idx[i] < len(
                                caption[i]) else 'Nan'

                        entry = {
                            'img':
                            img_path[i],
                            'epoch':
                            self.collection_epoch,
                            'caption':
                            ' '.join(
                                util.idx2str(c_i2w,
                                             (caption[i])[:cap_len[i]])) +
                            " | Reward: {:.2f}".format(base_rwd[i]),
                            'qaskprobs':
                            ' '.join([
                                "{:.2f}".format(x) if x > 0.0 else "_"
                                for x in dec_probs[i]
                            ]),
                            'rollout_caption':
                            ' '.join(
                                util.idx2str(c_i2w,
                                             (rollout[i])[:rollout_len[i]])) +
                            " | Reward: {:.2f}".format(rollout_rwd[i]),
                            'replace_caption':
                            ' '.join(
                                util.idx2str(c_i2w,
                                             (replace[i])[:cap_len[i]])) +
                            " | Reward: {:.2f}".format(replace_rwd[i]),
                            'index':
                            raw_idx[i],
                            'flag':
                            bool(flag[i]),
                            'word':
                            word,
                            'pos':
                            ' | '.join([
                                p_i2w[x.item()]
                                if x.item() in p_i2w else p_i2w[18]
                                for x in top_pos
                            ]),
                            'question':
                            ' '.join(
                                util.idx2str(q_i2w,
                                             (question[i])[:q_len[i]])) +
                            " | logprob: {}".format(
                                q_logprob[i, :q_len[i]].sum()),
                            'answers':
                            ' | '.join([a_i2w[x.item()] for x in top_ans]),
                            'words':
                            ' | '.join([c_i2w[x.item()] for x in top_cap]),
                            'refs': [
                                ' '.join(
                                    util.idx2str(c_i2w,
                                                 (refs[i, j])[:ref_lens[i,
                                                                        j]]))
                                for j in range(3)
                            ]
                        }

                        self.valLLvisualizer.add_entry(entry)

                    info = {
                        'eval/replace over rollout':
                        float(stat_rero) / (batch_size),
                        'eval/rollout over replace':
                        float(stat_rore) / (batch_size),
                        'eval/replace over all':
                        float(stat_reall) / (batch_size),
                        'eval/rollout over all':
                        float(stat_roall) / (batch_size),
                        'eval/question asking frequency (percent)':
                        float(ask_flag.sum().item()) / ask_flag.size(0)
                    }
                    util.step_logging(self.logger, info, self.eval_steps)
                    self.eval_steps += 1

        self.valLLvisualizer.update_html()

        # Dataset-averaged token accuracy for [base, corrected] captions.
        acc = [x / len(self.val_loader.dataset) for x in correct]

        for i in range(len(correct)):
            eval_scores.append(
                language_scores(caption_predictions[i],
                                self.opt.run_name,
                                self.result_path,
                                annFile=self.opt.coco_annotation_file))

        self.dmaker.train()

        return acc, eval_scores
Пример #5
0
    def do_iteration(self, image, refs, ref_lens, index, img_path):
        """Run one policy-gradient training step for the decision maker.

        Samples a question-asking decision and a greedy-decision baseline
        (self-critical style): for each, ask the teacher a question, build
        rollout and single-word-replace corrections, and take the better
        reward. The decision maker is updated with a masked policy-gradient
        loss on the sampled-minus-greedy reward gap (minus an ask penalty).
        The captioner, question generator and encoder are not updated.

        Returns:
            The scalar loss value for this batch.
        """
        self.d_optimizer.zero_grad()

        batch_size = image.size(0)

        # Two-slot accumulators: [0] sampled decision, [1] greedy baseline.
        caps, cmasks, qs, qlps, qmasks, aps, pps, atts, cps, dps = [], [], [], [], [], [], [], [], [], []

        # 1. Caption completely
        self.set_seed(self.opt.seed)

        r = self.captioner.sample(image,
                                  greedy=self.opt.cap_greedy,
                                  max_seq_len=self.opt.c_max_sentence_len + 1,
                                  temperature=self.opt.temperature)

        caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
            = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg

        # Don't backprop through captioner
        caption, cap_probs, cap_mask, pos_probs, attended_img \
            = [x.detach() for x in [caption, cap_probs, cap_mask, pos_probs, attended_img]]

        cap_len = cap_mask.long().sum(dim=1)
        caps.append(caption)
        cmasks.append(cap_mask)

        caption = self.pad_caption(caption, cap_len)

        # get the hidden state context
        source = torch.cat(
            [self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]],
            dim=1)
        r = self.fixed_caption_encoder(image, source, gt_pos=None, ss=False)
        h = r.hidden.detach()
        topk_words = [[y.detach() for y in x] for x in topk_words]

        # 2. Identify the best time to ask a question, excluding ended sentences, baseline against the greedy decision
        logit, valid_pos_mask = self.dmaker(
            h, attended_img, caption, cap_len, pos_probs, topk_words,
            self.captioner.caption_embedding.weight.data)
        masked_prob = masked_softmax(logit,
                                     cap_mask,
                                     valid_pos_mask,
                                     self.opt.dm_temperature,
                                     max_len=self.opt.c_max_sentence_len)

        dm_prob, ask_idx, ask_flag, ask_mask = self.sample_decision(
            masked_prob, cap_mask, greedy=False)
        _, ask_idx_greedy, ask_flag_greedy, ask_mask_greedy = self.sample_decision(
            masked_prob, cap_mask, greedy=True)

        dps.append(dm_prob.unsqueeze(1))

        # 3. Ask the teacher a question and get the answer
        ans, ans_mask, r = self.ask_question(image,
                                             caption,
                                             refs,
                                             pos_probs,
                                             h,
                                             att,
                                             ask_idx,
                                             q_greedy=self.opt.q_greedy,
                                             temperature=self.opt.temperature)
        ans_greedy, ans_mask_greedy, rg = self.ask_question(
            image,
            caption,
            refs,
            pos_probs,
            h,
            att,
            ask_idx_greedy,
            q_greedy=self.opt.q_greedy,
            temperature=self.opt.temperature)

        # logging stuff
        # Slot [0] comes from the sampled decision (r), slot [1] from the
        # greedy baseline (rg).
        cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
        cps.append(cap_probs[self.range_vector[:batch_size], ask_idx_greedy])
        pps.append(r.pos_prob[0])
        pps.append(rg.pos_prob[0])
        atts.append(r.att[0])
        atts.append(rg.att[0])
        qlps.append(r.q_logprob.unsqueeze(1))
        qlps.append(rg.q_logprob.unsqueeze(1))
        qmasks.append(r.q_mask)
        qmasks.append(rg.q_mask)
        qs.append(r.question.detach())
        qs.append(rg.question.detach())
        aps.append(r.ans_prob.detach())
        aps.append(rg.ans_prob.detach())

        # 4. Compute new captions based on teacher's answer
        # rollout caption
        r = self.caption_with_teacher_answer(image,
                                             ask_mask,
                                             ans_mask,
                                             greedy=self.opt.cap_greedy,
                                             temperature=self.opt.temperature)
        rg = self.caption_with_teacher_answer(image,
                                              ask_mask_greedy,
                                              ans_mask_greedy,
                                              greedy=self.opt.cap_greedy,
                                              temperature=self.opt.temperature)

        rollout, rollout_mask, rollout_greedy, rollout_mask_greedy = [
            x.detach()
            for x in [r.caption, r.cap_mask, rg.caption, rg.cap_mask]
        ]

        # replace caption
        replace = replace_word_in_caption(caps[0], ans, ask_idx, ask_flag)
        replace_greedy = replace_word_in_caption(caps[0], ans_greedy,
                                                 ask_idx_greedy,
                                                 ask_flag_greedy)

        # 5. Compute reward for captions
        base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0], dim=1), refs,
                                ref_lens, self.scorers, self.c_i2w)
        rollout_rwd = mixed_reward(rollout, torch.sum(rollout_mask, dim=1),
                                   refs, ref_lens, self.scorers, self.c_i2w)
        rollout_greedy_rwd = mixed_reward(
            rollout_greedy, torch.sum(rollout_mask_greedy, dim=1), refs,
            ref_lens, self.scorers, self.c_i2w)

        replace_rwd = mixed_reward(replace, torch.sum(cmasks[0], dim=1), refs,
                                   ref_lens, self.scorers, self.c_i2w)
        replace_greedy_rwd = mixed_reward(replace_greedy,
                                          torch.sum(cmasks[0], dim=1), refs,
                                          ref_lens, self.scorers, self.c_i2w)

        # Final reward per branch is the better of replace vs rollout.
        rwd = np.maximum(replace_rwd, rollout_rwd)
        rwd_greedy = np.maximum(replace_greedy_rwd, rollout_greedy_rwd)

        best_cap, best_cap_mask = choose_better_caption(
            replace_rwd, replace, cmasks[0], rollout_rwd, rollout,
            rollout_mask)
        best_cap_greedy, best_cap_greedy_mask = choose_better_caption(
            replace_greedy_rwd, replace_greedy, cmasks[0], rollout_greedy_rwd,
            rollout_greedy, rollout_mask_greedy)

        # some statistics on whether rollout or single-word-replace is better
        stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
            replace_rwd, rollout_rwd, base_rwd)
        caps.append(best_cap)
        cmasks.append(best_cap_mask)
        caps.append(best_cap_greedy)
        cmasks.append(best_cap_greedy_mask)

        # Backwards pass to train decision maker with policy gradient loss
        # Advantage = sampled reward - greedy baseline reward - ask penalty.
        reward_delta = torch.from_numpy(rwd - rwd_greedy).type(
            torch.float).to(device)
        reward_delta = reward_delta - self.opt.ask_penalty * ask_flag.float()

        loss = masked_PG(reward_delta.detach(),
                         torch.log(dps[0]).squeeze(), ask_flag.detach())
        loss.backward()

        self.d_optimizer.step()

        # also save the question asked and answer, and top-k predictions from captioner
        answers = [torch.max(x, dim=1)[1] for x in aps]
        topwords = [torch.topk(x, 20)[1] for x in cps]
        question_lens = [x.sum(dim=1) for x in qmasks]
        self.data_collector.collect_batch(index.clone(), caps, cmasks,
                                          [base_rwd, rwd, rwd_greedy], answers,
                                          qs, question_lens, topwords)

        # Logging
        if self.collection_steps % self.opt.print_every == 0:

            c_i2w = self.val_loader.dataset.c_i2w
            p_i2w = self.val_loader.dataset.p_i2w
            q_i2w = self.val_loader.dataset.q_i2w
            a_i2w = self.val_loader.dataset.a_i2w

            caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len,\
            flag, raw_idx, refs, ref_lens = \
                [util.to_np(x) for x in [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
                                         rollout_mask.long().sum(dim=1), masked_prob, qs[0], qlps[0],
                                         qmasks[0].long().sum(dim=1), ask_flag.long(),
                                         ask_idx, refs, ref_lens]]
            pos_probs, ans_probs, cap_probs = pps[0], aps[0], cps[0]

            for i in range(image.size(0)):

                top_pos = torch.topk(pos_probs, 3)[1]
                top_ans = torch.topk(ans_probs[i], 3)[1]
                top_cap = torch.topk(cap_probs[i], 5)[1]

                word = c_i2w[caption[i][raw_idx[i]]] if raw_idx[i] < len(
                    caption[i]) else 'Nan'

                entry = {
                    'img':
                    img_path[i],
                    'epoch':
                    self.collection_epoch,
                    'caption':
                    ' '.join(util.idx2str(c_i2w, (caption[i])[:cap_len[i]])) +
                    " | Reward: {:.2f}".format(base_rwd[i]),
                    'qaskprobs':
                    ' '.join([
                        "{:.2f}".format(x) if x > 0.0 else "_"
                        for x in dec_probs[i]
                    ]),
                    'rollout_caption':
                    ' '.join(util.idx2str(c_i2w,
                                          (rollout[i])[:rollout_len[i]])) +
                    " | Reward: {:.2f}".format(rollout_rwd[i]),
                    'replace_caption':
                    ' '.join(util.idx2str(c_i2w, (replace[i])[:cap_len[i]])) +
                    " | Reward: {:.2f}".format(replace_rwd[i]),
                    'index':
                    raw_idx[i],
                    'flag':
                    bool(flag[i]),
                    'word':
                    word,
                    'pos':
                    ' | '.join([
                        p_i2w[x.item()] if x.item() in p_i2w else p_i2w[18]
                        for x in top_pos
                    ]),
                    'question':
                    ' '.join(util.idx2str(q_i2w, (question[i])[:q_len[i]])) +
                    " | logprob: {}".format(q_logprob[i, :q_len[i]].sum()),
                    'answers':
                    ' | '.join([a_i2w[x.item()] for x in top_ans]),
                    'words':
                    ' | '.join([c_i2w[x.item()] for x in top_cap]),
                    'refs': [
                        ' '.join(
                            util.idx2str(c_i2w, (refs[i, j])[:ref_lens[i, j]]))
                        for j in range(3)
                    ]
                }

                self.trainLLvisualizer.add_entry(entry)

                # NOTE(review): this batch-level summary is logged once per
                # image because it sits inside the per-image loop; the
                # analogous block in evaluate_with_questions logs it once
                # per batch -- confirm the indentation is intentional.
                info = {
                    'collect/question logprob':
                    torch.mean(torch.cat(qlps, dim=1)).item(),
                    'collect/replace over rollout':
                    float(stat_rero) / (batch_size),
                    'collect/rollout over replace':
                    float(stat_rore) / (batch_size),
                    'collect/replace over all':
                    float(stat_reall) / (batch_size),
                    'collect/rollout over all':
                    float(stat_roall) / (batch_size),
                    'collect/sampled decision equals greedy decision':
                    float((ask_idx == ask_idx_greedy).sum().item()) /
                    (batch_size),
                    'collect/question asking frequency (percent)':
                    float(ask_flag.sum().item()) / ask_flag.size(0)
                }
                util.step_logging(self.logger, info, self.collection_steps)

        return loss.item()
    def validate(self):
        """Evaluate the VQA model on the validation set.

        Runs a forward pass over every validation batch, accumulating the
        answer-classification loss and top-1 accuracy, collecting
        predictions in the VQA2.0 submission format, and periodically
        writing Q&A visualizations to the HTML visualizer.

        Returns:
            samples: list of ``{'question_id', 'answer'}`` dicts (VQA2.0
                result format).
            [loss, accuracy]: both averaged over the validation dataset.
        """
        print("Validating")

        self.model.eval()
        loss, correct = 0.0, 0.0
        samples = []

        # No gradients are needed for validation; matches the other eval
        # loops in this file and avoids building the autograd graph.
        with torch.no_grad():
            for step, sample in enumerate(self.val_loader):

                image, question, question_len, answer, captions, cap_lens, img_path, que_id = sample
                image, question, answer, captions = [
                    x.to(device) for x in [image, question, answer, captions]
                ]

                # Forward pass
                result = self.model(image, question, captions)
                probs, logits = result.probs, result.logits

                loss += self.loss_function(logits, answer).item()

                # Compute top answer accuracy (argmax over predicted probs
                # vs. argmax over the soft/one-hot answer target).
                _, prediction_max_index = torch.max(probs, 1)
                _, answer_max_index = torch.max(answer, 1)

                correct += (
                    answer_max_index == prediction_max_index).float().sum().item()

                a_i2w = self.val_loader.dataset.a_i2w
                q_i2w = self.val_loader.dataset.q_i2w
                c_i2w = self.val_loader.dataset.c_i2w

                # Append prediction to VQA2.0 validation
                for i in range(image.size(0)):
                    samples.append({
                        'question_id': que_id[i].item(),
                        'answer': a_i2w[prediction_max_index[i].item()]
                    })

                # write image and Q&A to html file to visualize training progress
                if step % self.opt.visualize_every == 0:

                    # Show the top 3 predicted answers
                    pros, ans = torch.topk(probs, k=3, dim=1)
                    captions, cap_lens, question, question_len, pros, ans =\
                        [util.to_np(x) for x in [captions, cap_lens, question, question_len, pros, ans]]

                    # Visualize half the batch. Integer division is required:
                    # the original `/` produced a float, and range(float)
                    # raises TypeError on Python 3.
                    for i in range(image.size(0) // 2):

                        # Show question
                        que_arr = util.idx2str(q_i2w,
                                               (question[i])[:question_len[i]])

                        entry = {
                            'img':
                            img_path[i],
                            'epoch':
                            self.epoch,
                            'question':
                            ' '.join(que_arr),
                            'gt_ans':
                            a_i2w[answer_max_index[i].item()],
                            'predictions':
                            [[p, a_i2w[a]] for p, a in zip(pros[i], ans[i])],
                            'refs': [
                                ' '.join(
                                    util.idx2str(c_i2w,
                                                 (captions[i, j])[:cap_lens[i,
                                                                            j]]))
                                for j in range(3)
                            ]
                        }
                        self.visualizer.add_entry(entry)

        self.visualizer.update_html()
        # Reset
        self.model.train()

        return samples, [
            x / len(self.val_loader.dataset) for x in [loss, correct]
        ]
Пример #7
0
    def validate_captioner(self):
        """Evaluate the captioner on the validation set.

        Runs under ``torch.no_grad()``: accumulates word/POS loss and
        accuracy via ``validate_helper``, and every
        ``opt.val_print_every`` batches writes the greedy caption (with
        per-token POS predictions) and the top-3 beam-search captions to
        the HTML visualizer.

        Returns:
            list: ``[word_loss, word_cor, pos_cor, pos_loss]``, each
            divided by the size of the validation dataset.
        """
        self.captioner.eval()
        word_loss, word_cor, pos_cor, pos_loss = 0.0, 0.0, 0.0, 0.0

        with torch.no_grad():
            for step, sample in enumerate(self.val_loader):
                # Drop the last two batch entries and unpack the rest.
                # NOTE(review): assumes the loader yields exactly these
                # fields in this order — confirm against the dataset class.
                image, source, target, caption_len, refs, ref_lens, pos, img_path = sample[:
                                                                                           -2]
                # Re-bind `sample` to the device tensors the helper needs
                # (deliberately shadows the loop variable).
                sample = [
                    x.to(device)
                    for x in [image, source, target, caption_len, pos]
                ]
                r = validate_helper(sample, self.captioner,
                                    self.val_loader.dataset.max_c_len)

                # Element-wise accumulation of the four running totals.
                word_loss, word_cor, pos_cor, pos_loss = [
                    x + y for x, y in zip(
                        [word_loss, word_cor, pos_cor, pos_loss], r)
                ]

                # write image and predicted captions to html file to visualize training progress
                if step % self.opt.val_print_every == 0:
                    beam_size = 3
                    c_i2w = self.val_loader.dataset.c_i2w
                    p_i2w = self.val_loader.dataset.p_i2w
                    # From here on, refs/ref_lens are numpy arrays.
                    refs, ref_lens = [util.to_np(x) for x in [refs, ref_lens]]

                    # show top 3 beam search captions
                    result = self.captioner.sample_beam(sample[0],
                                                        beam_size=beam_size,
                                                        max_seq_len=17)
                    # Caption lengths are recovered by counting time steps
                    # with non-zero log-probability.
                    captions, lps, lens = [
                        util.to_np(x) for x in [
                            result.captions,
                            torch.sum(result.logprobs, dim=2),
                            torch.sum(result.logprobs.abs() > 0.00001, dim=2)
                        ]
                    ]

                    # show greedy caption
                    result = self.captioner.sample(sample[0],
                                                   greedy=True,
                                                   max_seq_len=17)
                    # Most likely POS tag (and its probability) per step.
                    pprob, pidx = torch.max(result.pos_prob, dim=2)
                    gcap, glp, glens, pidx, pprob = [
                        util.to_np(x) for x in [
                            result.caption,
                            torch.sum(result.log_prob, dim=1),
                            torch.sum(result.mask, dim=1), pidx, pprob
                        ]
                    ]

                    for i in range(image.size(0)):
                        cap_arr = util.idx2str(c_i2w, (gcap[i])[:glens[i]])
                        pos_arr = util.idx2pos(p_i2w, (pidx[i])[:glens[i]])
                        # "word (POS prob)" for each token of the greedy cap.
                        pos_pred = [
                            "{} ({} {:.2f})".format(cap_arr[j], pos_arr[j],
                                                    pprob[i, j])
                            for j in range(glens[i])
                        ]

                        entry = {
                            'img':
                            img_path[i],
                            'epoch':
                            self.cap_epoch,
                            'greedy_sample':
                            ' '.join(cap_arr) + " logprob: {}".format(glp[i]),
                            'pos_pred':
                            ' '.join(pos_pred),
                            'beamsearch': [
                                ' '.join(
                                    util.idx2str(c_i2w,
                                                 (captions[i, j])[:lens[i,
                                                                        j]])) +
                                " logprob: {}".format(lps[i, j])
                                for j in range(beam_size)
                            ],
                            # Only the first 3 reference captions are shown.
                            'refs': [
                                ' '.join(
                                    util.idx2str(c_i2w,
                                                 (refs[i, j])[:ref_lens[i,
                                                                        j]]))
                                for j in range(3)
                            ]
                        }

                        self.Cvisualizer.add_entry(entry)

        self.Cvisualizer.update_html()
        # Reset
        self.captioner.train()

        return [
            x / len(self.val_loader.dataset)
            for x in [word_loss, word_cor, pos_cor, pos_loss]
        ]
    def validate(self):
        """Run one validation pass of the question-generation model.

        Computes the masked cross-entropy loss and length-normalized
        per-token accuracy against the ground-truth questions, and counts
        how many greedily sampled questions the VQA expert answers
        correctly.  Every ``opt.visualize_every`` batches the image, the
        ground-truth/greedy/beam-search questions and the answer are
        written to the HTML visualizer.

        Returns:
            list: ``[avg_loss, avg_token_accuracy, answer_accuracy_pct]``
            averaged over the validation dataset.
        """
        loss, correct, correct_answers = 0.0, 0.0, 0
        self.model.eval()

        # Gradients are never needed during validation; wrapping the loop
        # in no_grad avoids building autograd graphs (and matches
        # validate_captioner, which already does this).
        with torch.no_grad():
            for step, batch in enumerate(self.val_loader):
                img_path = batch[-1]
                batch = [x.to(device) for x in batch[:-1]]

                image, question_len, source, target, caption, q_idx_vec, pos, att, context, refs, answer = self.compute_cap_features_val(
                    batch)

                # Teacher-forced logits for the ground-truth question.
                logits = self.model(image, caption, pos, context, att, source,
                                    q_idx_vec)

                batch_loss = masked_CE(logits, target, question_len)
                loss += batch_loss.item()
                predictions = seq_max_and_mask(
                    logits, question_len,
                    self.val_loader.dataset.max_q_len + 1)

                # Length-normalized token accuracy for this batch.
                correct += torch.sum((target == predictions).float() /
                                     question_len.float().unsqueeze(1)).item()

                # evaluate using VQA expert: greedily sample a question and
                # check whether the expert recovers the target answer.
                result = self.model.sample(
                    image,
                    caption,
                    pos,
                    context,
                    att,
                    q_idx_vec,
                    greedy=True,
                    max_seq_len=self.val_loader.dataset.max_q_len + 1)
                sample, log_probs, mask = result.question, result.log_prob, result.mask

                correct_answers += self.query_vqa(image, sample, refs, answer,
                                                  mask)

                # write image and Q&A to html file to visualize training progress
                if step % self.opt.visualize_every == 0:
                    beam_size = 3
                    c_i2w = self.val_loader.dataset.c_i2w
                    a_i2w = self.val_loader.dataset.a_i2w
                    q_i2w = self.val_loader.dataset.q_i2w

                    sample_len = mask.sum(dim=1)

                    _, _, beam_predictions, beam_lps = self.model.sample_beam(
                        image,
                        caption,
                        pos,
                        context,
                        att,
                        q_idx_vec,
                        beam_size=beam_size,
                        max_seq_len=15)
                    # Beam lengths are recovered by counting steps with a
                    # non-zero log-probability.
                    beam_predictions, lps, lens = [
                        util.to_np(x) for x in [
                            beam_predictions,
                            torch.sum(beam_lps, dim=2),
                            torch.sum(beam_lps != 0, dim=2)
                        ]
                    ]

                    target, question_len, sample, sample_len = [
                        util.to_np(x)
                        for x in [target, question_len, sample, sample_len]
                    ]

                    for i in range(image.size(0)):

                        entry = {
                            'img':
                            img_path[i],
                            'epoch':
                            self.epoch,
                            'answer':
                            a_i2w[answer[i].item()],
                            # Drop the trailing <eos> token from the target.
                            'gt_question':
                            ' '.join(
                                util.idx2str(
                                    q_i2w,
                                    (target[i])[:question_len[i] - 1])),
                            'greedy_question':
                            ' '.join(
                                util.idx2str(q_i2w,
                                             (sample[i])[:sample_len[i]])),
                            'beamsearch': [
                                ' '.join(
                                    util.idx2str(
                                        q_i2w,
                                        (beam_predictions[i, j])[:lens[i, j]]))
                                + " logprob: {}".format(lps[i, j])
                                for j in range(beam_size)
                            ]
                        }
                        self.visualizer.add_entry(entry)

        self.visualizer.update_html()

        # Reset
        self.model.train()

        n_examples = len(self.val_loader.dataset)
        return [
            loss / n_examples, correct / n_examples,
            100.0 * correct_answers / n_examples
        ]