示例#1
0
    def do_iteration(self, image, refs, ref_lens, index, img_path):
        """Run one collection step and one decision-maker update on a batch.

        The numbered comments below follow these stages:
          1. caption the images with ``self.captioner`` (its outputs are
             detached, so no gradient flows back into the captioner),
          2. let ``self.dmaker`` choose where to ask a question, once by
             sampling and once greedily (the greedy choice is the baseline),
          3. ask the teacher at both chosen positions,
          4. build new captions from the answers, both by rolling the
             captioner out and by replacing the single asked-about word,
          5. score the captions with ``mixed_reward`` and update the decision
             maker with ``masked_PG`` on (sampled reward - greedy reward),
             minus ``self.opt.ask_penalty`` wherever a question was asked.

        The best captions are handed to ``self.data_collector``. Every
        ``self.opt.print_every`` collection steps, per-sample examples go to
        ``self.trainLLvisualizer`` and batch statistics go to ``self.logger``.

        Args:
            image: image batch; ``image.size(0)`` is the batch size.
            refs, ref_lens: reference captions and their lengths, used by the
                teacher (``ask_question``) and by ``mixed_reward``.
            index: dataset indices of the batch (cloned before collection).
            img_path: per-sample image paths, used only for visualization.

        Returns:
            The decision-maker loss as a Python number (``loss.item()``).
        """
        self.d_optimizer.zero_grad()

        batch_size = image.size(0)

        caps, cmasks, qs, qlps, qmasks, aps, pps, cps, dps = [], [], [], [], [], [], [], [], []

        # 1. Caption completely
        # NOTE(review): set_seed is called with the same seed on every
        # iteration; confirm that repeating the random stream is intended.
        self.set_seed(self.opt.seed)

        r = self.captioner.sample(image,
                                  greedy=self.opt.cap_greedy,
                                  max_seq_len=self.opt.c_max_sentence_len + 1,
                                  temperature=self.opt.temperature)

        caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
            = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg

        # Don't backprop through captioner
        caption, cap_probs, cap_mask, pos_probs, attended_img \
            = [x.detach() for x in [caption, cap_probs, cap_mask, pos_probs, attended_img]]

        cap_len = cap_mask.long().sum(dim=1)
        caps.append(caption)
        cmasks.append(cap_mask)

        caption = self.pad_caption(caption, cap_len)

        # get the hidden state context: the encoder input is the caption
        # shifted right by one position, prefixed with self.ones_vector.
        source = torch.cat(
            [self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]],
            dim=1)
        r = self.fixed_caption_encoder(image, source, gt_pos=None, ss=False)
        h = r.hidden.detach()
        topk_words = [[y.detach() for y in x] for x in topk_words]

        # 2. Identify the best time to ask a question, excluding ended sentences, baseline against the greedy decision
        logit, valid_pos_mask = self.dmaker(
            h, attended_img, caption, cap_len, pos_probs, topk_words,
            self.captioner.caption_embedding.weight.data)
        masked_prob = masked_softmax(logit,
                                     cap_mask,
                                     valid_pos_mask,
                                     self.opt.dm_temperature,
                                     max_len=self.opt.c_max_sentence_len)

        dm_prob, ask_idx, ask_flag, ask_mask = self.sample_decision(
            masked_prob, cap_mask, greedy=False)
        _, ask_idx_greedy, ask_flag_greedy, ask_mask_greedy = self.sample_decision(
            masked_prob, cap_mask, greedy=True)

        dps.append(dm_prob.unsqueeze(1))

        # 3. Ask the teacher a question and get the answer
        ans, ans_mask, r = self.ask_question(image,
                                             caption,
                                             refs,
                                             pos_probs,
                                             h,
                                             att,
                                             ask_idx,
                                             q_greedy=self.opt.q_greedy,
                                             temperature=self.opt.temperature)
        ans_greedy, ans_mask_greedy, rg = self.ask_question(
            image,
            caption,
            refs,
            pos_probs,
            h,
            att,
            ask_idx_greedy,
            q_greedy=self.opt.q_greedy,
            temperature=self.opt.temperature)

        # logging stuff: index 0 = sampled decision, index 1 = greedy decision
        cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
        cps.append(cap_probs[self.range_vector[:batch_size], ask_idx_greedy])
        pps.append(r.pos_prob[0])
        pps.append(rg.pos_prob[0])
        qlps.append(r.q_logprob.unsqueeze(1))
        qlps.append(rg.q_logprob.unsqueeze(1))
        qmasks.append(r.q_mask)
        qmasks.append(rg.q_mask)
        qs.append(r.question.detach())
        qs.append(rg.question.detach())
        aps.append(r.ans_prob.detach())
        aps.append(rg.ans_prob.detach())

        # 4. Compute new captions based on teacher's answer
        # rollout caption
        r = self.caption_with_teacher_answer(image,
                                             ask_mask,
                                             ans_mask,
                                             greedy=self.opt.cap_greedy,
                                             temperature=self.opt.temperature)
        rg = self.caption_with_teacher_answer(image,
                                              ask_mask_greedy,
                                              ans_mask_greedy,
                                              greedy=self.opt.cap_greedy,
                                              temperature=self.opt.temperature)

        rollout, rollout_mask, rollout_greedy, rollout_mask_greedy = [
            x.detach()
            for x in [r.caption, r.cap_mask, rg.caption, rg.cap_mask]
        ]

        # replace caption
        replace = replace_word_in_caption(caps[0], ans, ask_idx, ask_flag)
        replace_greedy = replace_word_in_caption(caps[0], ans_greedy,
                                                 ask_idx_greedy,
                                                 ask_flag_greedy)

        # 5. Compute reward for captions
        base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0], dim=1), refs,
                                ref_lens, self.scorers, self.c_i2w)
        rollout_rwd = mixed_reward(rollout, torch.sum(rollout_mask, dim=1),
                                   refs, ref_lens, self.scorers, self.c_i2w)
        rollout_greedy_rwd = mixed_reward(
            rollout_greedy, torch.sum(rollout_mask_greedy, dim=1), refs,
            ref_lens, self.scorers, self.c_i2w)

        replace_rwd = mixed_reward(replace, torch.sum(cmasks[0], dim=1), refs,
                                   ref_lens, self.scorers, self.c_i2w)
        replace_greedy_rwd = mixed_reward(replace_greedy,
                                          torch.sum(cmasks[0], dim=1), refs,
                                          ref_lens, self.scorers, self.c_i2w)

        # Per sample, the better of the two ways of using the answer counts.
        rwd = np.maximum(replace_rwd, rollout_rwd)
        rwd_greedy = np.maximum(replace_greedy_rwd, rollout_greedy_rwd)

        best_cap, best_cap_mask = choose_better_caption(
            replace_rwd, replace, cmasks[0], rollout_rwd, rollout,
            rollout_mask)
        best_cap_greedy, best_cap_greedy_mask = choose_better_caption(
            replace_greedy_rwd, replace_greedy, cmasks[0], rollout_greedy_rwd,
            rollout_greedy, rollout_mask_greedy)

        # some statistics on whether rollout or single-word-replace is better
        stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
            replace_rwd, rollout_rwd, base_rwd)
        caps.append(best_cap)
        cmasks.append(best_cap_mask)
        caps.append(best_cap_greedy)
        cmasks.append(best_cap_greedy_mask)

        # Backwards pass to train decision maker with policy gradient loss.
        # The reward signal is the sampled decision's reward relative to the
        # greedy baseline, with a fixed penalty wherever a question was asked.
        reward_delta = torch.from_numpy(rwd - rwd_greedy).type(
            torch.float).to(device)
        reward_delta = reward_delta - self.opt.ask_penalty * ask_flag.float()

        loss = masked_PG(reward_delta.detach(),
                         torch.log(dps[0]).squeeze(), ask_flag.detach())
        loss.backward()

        self.d_optimizer.step()

        # also save the question asked and answer, and top-k predictions from captioner
        answers = [torch.max(x, dim=1)[1] for x in aps]
        topwords = [torch.topk(x, 20)[1] for x in cps]
        question_lens = [x.sum(dim=1) for x in qmasks]
        self.data_collector.collect_batch(index.clone(), caps, cmasks,
                                          [base_rwd, rwd, rwd_greedy], answers,
                                          qs, question_lens, topwords)

        # Logging
        if self.collection_steps % self.opt.print_every == 0:

            c_i2w = self.val_loader.dataset.c_i2w
            p_i2w = self.val_loader.dataset.p_i2w
            q_i2w = self.val_loader.dataset.q_i2w
            a_i2w = self.val_loader.dataset.a_i2w

            caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len,\
            flag, raw_idx, refs, ref_lens = \
                [util.to_np(x) for x in [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
                                         rollout_mask.long().sum(dim=1), masked_prob, qs[0], qlps[0],
                                         qmasks[0].long().sum(dim=1), ask_flag.long(),
                                         ask_idx, refs, ref_lens]]
            pos_probs, ans_probs, cap_probs = pps[0], aps[0], cps[0]

            # Does not depend on the sample index, so compute it once.
            # NOTE(review): unlike ans_probs/cap_probs this is not indexed per
            # sample; pps[0] is r.pos_prob[0] - confirm its shape.
            top_pos = torch.topk(pos_probs, 3)[1]

            for i in range(image.size(0)):

                top_ans = torch.topk(ans_probs[i], 3)[1]
                top_cap = torch.topk(cap_probs[i], 5)[1]

                word = c_i2w[caption[i][raw_idx[i]]] if raw_idx[i] < len(
                    caption[i]) else 'Nan'

                entry = {
                    'img':
                    img_path[i],
                    'epoch':
                    self.collection_epoch,
                    'caption':
                    ' '.join(util.idx2str(c_i2w, (caption[i])[:cap_len[i]])) +
                    " | Reward: {:.2f}".format(base_rwd[i]),
                    'qaskprobs':
                    ' '.join([
                        "{:.2f}".format(x) if x > 0.0 else "_"
                        for x in dec_probs[i]
                    ]),
                    'rollout_caption':
                    ' '.join(util.idx2str(c_i2w,
                                          (rollout[i])[:rollout_len[i]])) +
                    " | Reward: {:.2f}".format(rollout_rwd[i]),
                    'replace_caption':
                    ' '.join(util.idx2str(c_i2w, (replace[i])[:cap_len[i]])) +
                    " | Reward: {:.2f}".format(replace_rwd[i]),
                    'index':
                    raw_idx[i],
                    'flag':
                    bool(flag[i]),
                    'word':
                    word,
                    'pos':
                    ' | '.join([
                        p_i2w[x.item()] if x.item() in p_i2w else p_i2w[18]
                        for x in top_pos
                    ]),
                    'question':
                    ' '.join(util.idx2str(q_i2w, (question[i])[:q_len[i]])) +
                    " | logprob: {}".format(q_logprob[i, :q_len[i]].sum()),
                    'answers':
                    ' | '.join([a_i2w[x.item()] for x in top_ans]),
                    'words':
                    ' | '.join([c_i2w[x.item()] for x in top_cap]),
                    'refs': [
                        ' '.join(
                            util.idx2str(c_i2w, (refs[i, j])[:ref_lens[i, j]]))
                        for j in range(3)
                    ]
                }

                self.trainLLvisualizer.add_entry(entry)

            # Batch-level statistics: logged once per step, not once per sample.
            info = {
                'collect/question logprob':
                torch.mean(torch.cat(qlps, dim=1)).item(),
                'collect/replace over rollout':
                float(stat_rero) / (batch_size),
                'collect/rollout over replace':
                float(stat_rore) / (batch_size),
                'collect/replace over all':
                float(stat_reall) / (batch_size),
                'collect/rollout over all':
                float(stat_roall) / (batch_size),
                'collect/sampled decision equals greedy decision':
                float((ask_idx == ask_idx_greedy).sum().item()) /
                (batch_size),
                'collect/question asking frequency (percent)':
                float(ask_flag.sum().item()) / ask_flag.size(0)
            }
            util.step_logging(self.logger, info, self.collection_steps)

        return loss.item()
示例#2
0
    def evaluate_with_questions(self):
        """Evaluate the decision maker on the validation set.

        For each validation batch this greedily captions the images, lets
        ``self.dmaker`` greedily pick where to ask a question, asks the
        teacher, and builds two answer-informed captions: a rollout from the
        captioner and a single-word replacement. ``choose_better_caption``
        keeps the better one per sample according to ``mixed_reward``.

        Returns:
            acc: for [original caption, best improved caption], the per-caption
                fraction of positions that match ``target``, summed over the
                dataset and divided by its size.
            eval_scores: ``language_scores`` results for the same two sets of
                captions.
        """
        self.std_logger.info("Validating decision maker")

        self.dmaker.eval()
        self.captioner.train()
        self.qgen.train()
        # if self.opt.cap_eval:
        #     self.captioner.eval()
        # else:
        #     self.captioner.train()
        #
        # if self.opt.quegen_eval:
        #     self.question_generator.eval()
        # else:
        #     self.question_generator.train()

        c_i2w = self.val_loader.dataset.c_i2w

        # Index 0: original captions; index 1: best answer-informed captions.
        correct = [0.0, 0.0]
        caption_predictions = [[], []]

        with torch.no_grad():

            for step, sample in enumerate(self.val_loader):

                image, refs, target, caption_len = [
                    x.to(device)
                    for x in [sample[0], sample[4], sample[2], sample[3]]
                ]
                ref_lens, img_path, img_id = sample[5], sample[7], sample[9]

                batch_size = image.size(0)

                # 1. Caption completely
                self.set_seed(self.opt.seed)

                r = self.captioner.sample(
                    image,
                    greedy=True,
                    max_seq_len=self.opt.c_max_sentence_len + 1)

                caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
                    = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg

                cap_len = cap_mask.long().sum(dim=1)
                caps = [caption]
                cmasks = [cap_mask]

                caption = self.pad_caption(caption, cap_len)

                # get the hidden state context: the encoder input is the
                # caption shifted right by one, prefixed with self.ones_vector.
                source = torch.cat(
                    [self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]],
                    dim=1)

                r = self.fixed_caption_encoder(image,
                                               source,
                                               gt_pos=None,
                                               ss=False)
                h = r.hidden

                # 2. Identify the best time to ask a question, excluding ended sentences
                logit, valid_pos_mask = self.dmaker(
                    h, attended_img, caption, cap_len, pos_probs, topk_words,
                    self.captioner.caption_embedding.weight.data)
                masked_prob = masked_softmax(
                    logit,
                    cap_mask,
                    valid_pos_mask,
                    max_len=self.opt.c_max_sentence_len)
                _, ask_idx, ask_flag, ask_mask = self.sample_decision(
                    masked_prob, cap_mask, greedy=True)

                # 3. Ask the teacher a question and get the answer
                ans, ans_mask, r = self.ask_question(image,
                                                     caption,
                                                     refs,
                                                     pos_probs,
                                                     h,
                                                     att,
                                                     ask_idx,
                                                     q_greedy=True)

                # Question/answer outputs kept for the visualization below.
                asked_cap_probs = cap_probs[self.range_vector[:batch_size], ask_idx]
                q_pos_prob = r.pos_prob[0]
                q_logprob_all = r.q_logprob.unsqueeze(1)
                q_mask = r.q_mask
                q_tokens = r.question
                a_probs = r.ans_prob

                # 4. Compute new captions based on teacher's answer
                # rollout caption
                r = self.caption_with_teacher_answer(image,
                                                     ask_mask,
                                                     ans_mask,
                                                     greedy=True)

                rollout = r.caption
                rollout_mask = r.cap_mask

                # replace caption
                replace = replace_word_in_caption(caps[0], ans, ask_idx,
                                                  ask_flag)

                base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0], dim=1),
                                        refs, ref_lens, self.scorers,
                                        self.c_i2w)
                rollout_rwd = mixed_reward(rollout,
                                           torch.sum(rollout_mask, dim=1),
                                           refs, ref_lens, self.scorers,
                                           self.c_i2w)
                replace_rwd = mixed_reward(replace,
                                           torch.sum(cmasks[0], dim=1), refs,
                                           ref_lens, self.scorers, self.c_i2w)

                stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
                    replace_rwd, rollout_rwd, base_rwd)

                best_cap, best_cap_mask = choose_better_caption(
                    replace_rwd, replace, cmasks[0], rollout_rwd, rollout,
                    rollout_mask)

                caps.append(best_cap)
                cmasks.append(best_cap_mask)

                # Collect captions for coco evaluation
                img_id = util.to_np(img_id)

                for i, (cap_tokens, cap_tokens_mask) in enumerate(zip(caps, cmasks)):
                    words = util.to_np(cap_tokens)
                    lens = util.to_np(cap_tokens_mask.sum(dim=1))

                    for j in range(image.size(0)):
                        cap_str = ' '.join(c_i2w[ind] for ind in words[j][:lens[j]])
                        caption_predictions[i].append(
                            {'image_id': img_id[j], 'caption': cap_str})

                # Per-caption fraction of positions matching the target,
                # with padding positions zeroed by the mask.
                for i in range(len(correct)):
                    predictions = caps[i] * cmasks[i].long()
                    correct[i] += (
                        (target == predictions).float() /
                        caption_len.float().unsqueeze(1)).sum().item()

                # Logging
                if step % self.opt.val_print_every == 0:

                    p_i2w = self.val_loader.dataset.p_i2w
                    q_i2w = self.val_loader.dataset.q_i2w
                    a_i2w = self.val_loader.dataset.a_i2w

                    caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len, \
                    flag, raw_idx, refs, ref_lens = \
                        [util.to_np(x) for x in [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
                                                 rollout_mask.long().sum(dim=1), masked_prob, q_tokens,
                                                 q_logprob_all, q_mask.long().sum(dim=1), ask_flag.long(),
                                                 ask_idx, refs, ref_lens]]
                    pos_probs, ans_probs, cap_probs = q_pos_prob, a_probs, asked_cap_probs

                    # Does not depend on the sample index, so compute it once.
                    # NOTE(review): not indexed per sample (pos_probs is
                    # r.pos_prob[0]) - confirm its shape.
                    top_pos = torch.topk(pos_probs, 3)[1]

                    for i in range(image.size(0)):
                        top_ans = torch.topk(ans_probs[i], 3)[1]
                        top_cap = torch.topk(cap_probs[i], 5)[1]

                        word = c_i2w[caption[i][
                            raw_idx[i]]] if raw_idx[i] < len(
                                caption[i]) else 'Nan'

                        entry = {
                            'img':
                            img_path[i],
                            'epoch':
                            self.collection_epoch,
                            'caption':
                            ' '.join(
                                util.idx2str(c_i2w,
                                             (caption[i])[:cap_len[i]])) +
                            " | Reward: {:.2f}".format(base_rwd[i]),
                            'qaskprobs':
                            ' '.join([
                                "{:.2f}".format(x) if x > 0.0 else "_"
                                for x in dec_probs[i]
                            ]),
                            'rollout_caption':
                            ' '.join(
                                util.idx2str(c_i2w,
                                             (rollout[i])[:rollout_len[i]])) +
                            " | Reward: {:.2f}".format(rollout_rwd[i]),
                            'replace_caption':
                            ' '.join(
                                util.idx2str(c_i2w,
                                             (replace[i])[:cap_len[i]])) +
                            " | Reward: {:.2f}".format(replace_rwd[i]),
                            'index':
                            raw_idx[i],
                            'flag':
                            bool(flag[i]),
                            'word':
                            word,
                            'pos':
                            ' | '.join([
                                p_i2w[x.item()]
                                if x.item() in p_i2w else p_i2w[18]
                                for x in top_pos
                            ]),
                            'question':
                            ' '.join(
                                util.idx2str(q_i2w,
                                             (question[i])[:q_len[i]])) +
                            " | logprob: {}".format(
                                q_logprob[i, :q_len[i]].sum()),
                            'answers':
                            ' | '.join([a_i2w[x.item()] for x in top_ans]),
                            'words':
                            ' | '.join([c_i2w[x.item()] for x in top_cap]),
                            'refs': [
                                ' '.join(
                                    util.idx2str(c_i2w,
                                                 (refs[i, j])[:ref_lens[i, j]]))
                                for j in range(3)
                            ]
                        }

                        self.valLLvisualizer.add_entry(entry)

                    info = {
                        'eval/replace over rollout':
                        float(stat_rero) / (batch_size),
                        'eval/rollout over replace':
                        float(stat_rore) / (batch_size),
                        'eval/replace over all':
                        float(stat_reall) / (batch_size),
                        'eval/rollout over all':
                        float(stat_roall) / (batch_size),
                        'eval/question asking frequency (percent)':
                        float(ask_flag.sum().item()) / ask_flag.size(0)
                    }
                    util.step_logging(self.logger, info, self.eval_steps)
                    self.eval_steps += 1

        self.valLLvisualizer.update_html()

        acc = [x / len(self.val_loader.dataset) for x in correct]

        eval_scores = [
            language_scores(predictions,
                            self.opt.run_name,
                            self.result_path,
                            annFile=self.opt.coco_annotation_file)
            for predictions in caption_predictions
        ]

        self.dmaker.train()

        return acc, eval_scores
    def validate(self):
        """Evaluate ``self.model`` (the VQA model) on the validation set.

        For every validation batch this runs the model without tracking
        gradients and does the following:
          - accumulates the loss and the number of top-1 answers that match
            the argmax of ``answer``,
          - records one ``{'question_id', 'answer'}`` prediction per question,
          - every ``self.opt.visualize_every`` steps, writes the first half of
            the batch (question, ground-truth answer, top-3 predictions,
            three reference captions) to ``self.visualizer``.

        The model is put back in train mode before returning.

        Returns:
            samples: list of ``{'question_id': ..., 'answer': str}`` dicts.
            [loss, accuracy]: summed per-batch loss and number of correct
                predictions, each divided by the dataset size.
        """
        print("Validating")

        self.model.eval()
        loss, correct = 0.0, 0.0
        samples = []

        dataset = self.val_loader.dataset
        a_i2w, q_i2w, c_i2w = dataset.a_i2w, dataset.q_i2w, dataset.c_i2w

        # Inference only: no autograd graph is needed during validation.
        with torch.no_grad():
            for step, sample in enumerate(self.val_loader):

                image, question, question_len, answer, captions, cap_lens, img_path, que_id = sample
                image, question, answer, captions = [
                    x.to(device) for x in [image, question, answer, captions]
                ]

                # Forward pass
                result = self.model(image, question, captions)
                probs, logits = result.probs, result.logits

                # NOTE(review): per-batch losses are summed and later divided
                # by the number of samples; if loss_function already averages
                # over the batch this under-reports the loss - confirm.
                loss += self.loss_function(logits, answer).item()

                # Compute top answer accuracy
                _, prediction_max_index = torch.max(probs, 1)
                _, answer_max_index = torch.max(answer, 1)

                correct += (answer_max_index ==
                            prediction_max_index).float().sum().item()

                # Append prediction to VQA2.0 validation
                for i in range(image.size(0)):
                    samples.append({
                        'question_id': que_id[i].item(),
                        'answer': a_i2w[prediction_max_index[i].item()]
                    })

                # write image and Q&A to html file to visualize training progress
                if step % self.opt.visualize_every == 0:

                    # Show the top 3 predicted answers
                    pros, ans = torch.topk(probs, k=3, dim=1)
                    captions, cap_lens, question, question_len, pros, ans = \
                        [util.to_np(x) for x in [captions, cap_lens, question, question_len, pros, ans]]

                    # Visualize the first half of the batch. Integer division
                    # is required: range() rejects the float "/" returns.
                    for i in range(image.size(0) // 2):

                        # Show question
                        que_arr = util.idx2str(q_i2w,
                                               (question[i])[:question_len[i]])

                        entry = {
                            'img':
                            img_path[i],
                            'epoch':
                            self.epoch,
                            'question':
                            ' '.join(que_arr),
                            'gt_ans':
                            a_i2w[answer_max_index[i].item()],
                            'predictions':
                            [[p, a_i2w[a]] for p, a in zip(pros[i], ans[i])],
                            'refs': [
                                ' '.join(
                                    util.idx2str(c_i2w,
                                                 (captions[i, j])[:cap_lens[i, j]]))
                                for j in range(3)
                            ]
                        }
                        self.visualizer.add_entry(entry)

        self.visualizer.update_html()
        # Reset
        self.model.train()

        return samples, [
            x / len(self.val_loader.dataset) for x in [loss, correct]
        ]
示例#4
0
    def save_collected_data(self, output_dir, logger):
        """Build the next training set from the collected captions and pickle it.

        Images are visited in reverse order of
        ``self.get_average_reward_each_image()``. That list is presumably
        sorted by ascending average reward, so the top images come first —
        confirm. For the first ``int(len(self.imgId2idx) * self.H)`` images,
        each sample's best collected caption is cleaned, POS-tagged and kept.
        The following samples fall back to their ground-truth caption, with
        weight ``gt_reward`` and ``caption_type`` 0:
          - samples whose cleaned caption and POS tags differ in length,
          - every sample of the remaining images.

        Args:
            output_dir: path of the pickle file to write (despite the name,
                it is opened directly as a file).
            logger: logger that receives the collection summary.

        Side effects:
            Writes the pickle file, mutates the ``gt_data`` dicts in
            ``self.collected_data`` and calls ``self.set_gather_stats``.
        """
        img_avg_rwd = self.get_average_reward_each_image()

        imgs_to_gather = int(len(self.imgId2idx) * self.H)
        # Reward given to every ground-truth fallback caption, read from the
        # per-image ranking and clamped to the last valid index.
        gt_reward_idx = min(
            int(len(img_avg_rwd) - (1 - self.lamda) * imgs_to_gather),
            len(img_avg_rwd) - 1)
        gt_reward = img_avg_rwd[gt_reward_idx][1]

        new_data = []
        imgs_collected, captions_collected, skipped_pos_parse_error = 0, 0, 0
        q_gathers, no_q_gathers, total_cider = 0, 0, 0.0
        gaveup_list = []

        img_avg_rwd.reverse()

        # collect the best caption for the top H% images
        for imgId, _ in img_avg_rwd:

            idxs = self.imgId2idx[imgId]

            if imgs_collected < imgs_to_gather:

                for idx in idxs:
                    item = self.get_sample_best_caption(idx)

                    cap_str = idx2str(self.c_i2w, item['caption'])
                    pos_str = [
                        p for w, p in self.stanfordnlp.pos_tag(' '.join(
                            [clean_str(x) for x in cap_str]))
                    ]

                    cap_clean, pos_clean = self.clean_cap_pos(cap_str, pos_str)
                    item['caption'] = cap_clean

                    if len(pos_clean) != len(cap_clean):
                        # Tagger output does not align with the tokens.
                        gaveup_list.append(idx)
                        skipped_pos_parse_error += 1
                        continue

                    if idx in self.question_set:
                        q_gathers += 1
                        item['caption_type'] = 1
                    elif idx in self.no_question_set:
                        no_q_gathers += 1
                        item['caption_type'] = 2
                    # NOTE(review): an idx in neither set gets no
                    # 'caption_type' key - confirm that cannot happen.

                    item['pos'] = [
                        self.p_w2i[x]
                        if x in self.p_w2i else self.p_w2i['unknown']
                        for x in pos_clean
                    ]

                    new_data.append(item)

                    captions_collected += 1
                    total_cider += item['weight']

            else:
                gaveup_list.extend(idxs)

            imgs_collected += 1

        # get the ground truth captions for the rest of the images
        for idx in gaveup_list:
            item = self.collected_data[idx]['gt_data']
            item['weight'] = gt_reward
            item['caption_type'] = 0
            new_data.append(item)

        assert len(new_data) == len(self.data)

        # Guard against division by zero when nothing was collected
        # (e.g. H == 0, or every POS parse failed).
        avg_collect_reward = (total_cider / captions_collected
                              if captions_collected else 0.0)

        # Logging
        logger.info("Collecting the top {}/{} images.".format(
            imgs_to_gather, len(self.imgId2idx)))
        logger.info(
            "Collected {}/{} captions, and asked {}/{} samples for ground truth."
            .format(captions_collected, len(self.data), len(gaveup_list),
                    len(self.data)))
        logger.info(
            "Of the {} captions collected, {} were from asking a question and {} were from not."
            .format(captions_collected, q_gathers, no_q_gathers))
        logger.info("{} captions were skipped due to pos parsing error".format(
            skipped_pos_parse_error))
        logger.info(
            "The average gathered reward was {} and the 80 percentile reward (GT reward is set to this) is {}."
            .format(avg_collect_reward, gt_reward))

        # Save data; the context manager guarantees the file is closed.
        with open(output_dir, 'wb') as out_file:
            pickle.dump(
                {
                    "data": new_data,
                    "c_dicts": [self.c_i2w, self.c_w2i],
                    "pos_dicts": [self.p_i2w, self.p_w2i],
                    "a_dicts": [self.a_i2w, self.a_w2i],
                    "q_dicts": [self.q_i2w, self.q_w2i],
                    "special_symbols": self.special_symbols,
                    "gt_caps_reward": gt_reward
                },
                out_file,
                protocol=pickle.HIGHEST_PROTOCOL)

        stats = {
            'collects': captions_collected,
            'num gt': len(gaveup_list),
            'q collects': q_gathers,
            'no q collects': no_q_gathers,
            'avg collect reward': avg_collect_reward,
            'gt reward': gt_reward,
            'pos parse error': skipped_pos_parse_error,
        }
        self.set_gather_stats(stats)
示例#5
0
# Roll out a new caption conditioned on the teacher's answer.
set_seed(SEED)
r = captioner.sample_with_teacher_answer(
    image,
    ask_mask.unsqueeze(0),
    answer_mask,
    torch.zeros([1, 1, 512], dtype=torch.float, device=device),
    torch.ones([1], dtype=torch.long, device=device),
    17,
    True,
)
rollout = r.caption
rollout_mask = r.cap_mask
rollout_len = rollout_mask.long().sum(dim=1)

# Single-word edit: overwrite the asked-about position with the answer token.
replace = caption.clone()
replace[0, ask_idx] = answer

# Keep only the valid prefix of the first row of each sequence, as numpy.
caption, rollout, replace, question = (
    seq[0, :length].cpu().numpy()
    for seq, length in (
        (caption, cap_len),
        (rollout, rollout_len),
        (replace, cap_len),
        (question, q_len),
    )
)
ask_idx = ask_idx.item()

# Decode token ids into space-separated words.
caption, rollout, replace = (
    ' '.join(idx2str(ci2w, tokens)) for tokens in (caption, rollout, replace)
)
question = ' '.join(idx2str(qi2w, question))

print(caption, rollout, replace, question, ask_idx, sep='\n')
示例#6
0
    def validate_captioner(self):
        """Run one validation pass of the captioner and log sample captions.

        For every batch of ``self.val_loader`` (with gradients disabled), the
        four statistics returned by ``validate_helper`` are accumulated. Every
        ``self.opt.val_print_every`` steps, the following are added to
        ``self.Cvisualizer``:

        - the top ``beam_size`` beam-search captions,
        - the greedy caption with its arg-max POS predictions,
        - up to three reference captions.

        The HTML page is refreshed after the pass.

        The captioner is switched to eval mode for the pass. It is always
        switched back to train mode afterwards, even when an exception
        propagates out of the pass.

        Returns:
            list: ``[word_loss, word_cor, pos_cor, pos_loss]``, each summed
            over batches and divided by ``len(self.val_loader.dataset)``.
        """
        self.captioner.eval()
        word_loss, word_cor, pos_cor, pos_loss = 0.0, 0.0, 0.0, 0.0

        try:
            with torch.no_grad():
                for step, sample in enumerate(self.val_loader):
                    # The last two fields of each batch are not used here.
                    (image, source, target, caption_len, refs, ref_lens, pos,
                     img_path) = sample[:-2]
                    sample = [
                        x.to(device)
                        for x in [image, source, target, caption_len, pos]
                    ]
                    r = validate_helper(sample, self.captioner,
                                        self.val_loader.dataset.max_c_len)

                    word_loss, word_cor, pos_cor, pos_loss = [
                        x + y for x, y in zip(
                            [word_loss, word_cor, pos_cor, pos_loss], r)
                    ]

                    # Only every val_print_every-th batch is written to the
                    # html file to visualize training progress.
                    if step % self.opt.val_print_every != 0:
                        continue

                    beam_size = 3
                    c_i2w = self.val_loader.dataset.c_i2w
                    p_i2w = self.val_loader.dataset.p_i2w
                    refs, ref_lens = [util.to_np(x) for x in [refs, ref_lens]]
                    # Show up to 3 references, but never index past the
                    # number of references actually present in the batch.
                    n_refs = min(3, refs.shape[1])

                    # top `beam_size` beam search captions
                    result = self.captioner.sample_beam(sample[0],
                                                        beam_size=beam_size,
                                                        max_seq_len=17)
                    captions, lps, lens = [
                        util.to_np(x) for x in [
                            result.captions,
                            torch.sum(result.logprobs, dim=2),
                            # length = number of steps whose |log-prob|
                            # exceeds 1e-5
                            torch.sum(result.logprobs.abs() > 0.00001, dim=2)
                        ]
                    ]

                    # greedy caption plus the most likely POS tag per word
                    result = self.captioner.sample(sample[0],
                                                   greedy=True,
                                                   max_seq_len=17)
                    pprob, pidx = torch.max(result.pos_prob, dim=2)
                    gcap, glp, glens, pidx, pprob = [
                        util.to_np(x) for x in [
                            result.caption,
                            torch.sum(result.log_prob, dim=1),
                            torch.sum(result.mask, dim=1), pidx, pprob
                        ]
                    ]

                    for i in range(image.size(0)):
                        cap_arr = util.idx2str(c_i2w, gcap[i][:glens[i]])
                        pos_arr = util.idx2pos(p_i2w, pidx[i][:glens[i]])
                        pos_pred = [
                            "{} ({} {:.2f})".format(cap_arr[j], pos_arr[j],
                                                    pprob[i, j])
                            for j in range(glens[i])
                        ]

                        entry = {
                            'img': img_path[i],
                            'epoch': self.cap_epoch,
                            'greedy_sample':
                            ' '.join(cap_arr) + " logprob: {}".format(glp[i]),
                            'pos_pred': ' '.join(pos_pred),
                            'beamsearch': [
                                ' '.join(
                                    util.idx2str(c_i2w,
                                                 captions[i, j][:lens[i, j]]))
                                + " logprob: {}".format(lps[i, j])
                                for j in range(beam_size)
                            ],
                            'refs': [
                                ' '.join(
                                    util.idx2str(c_i2w,
                                                 refs[i, j][:ref_lens[i, j]]))
                                for j in range(n_refs)
                            ],
                        }
                        self.Cvisualizer.add_entry(entry)

            self.Cvisualizer.update_html()
        finally:
            # Reset to training mode no matter how the pass ended.
            self.captioner.train()

        n_samples = len(self.val_loader.dataset)
        return [
            x / n_samples for x in [word_loss, word_cor, pos_cor, pos_loss]
        ]

    def validate(self):
        """Evaluate the question-generation model on the validation set.

        For every validation batch:

        - compute the teacher-forced masked cross-entropy loss (``masked_CE``)
          and a per-question token accuracy;
        - sample a greedy question and add the count returned by
          ``self.query_vqa`` to ``correct_answers``;
        - every ``self.opt.visualize_every`` steps, run beam search and add
          the ground-truth, greedy and beam questions to ``self.visualizer``.

        The HTML page is refreshed at the end of the pass.

        The whole pass runs under ``torch.no_grad()``, so no autograd graph
        is built. The model is always returned to train mode afterwards, even
        if an exception propagates.

        Returns:
            list: ``[loss / N, correct / N, 100.0 * correct_answers / N]``,
            where ``N = len(self.val_loader.dataset)``.
        """
        loss, correct, correct_answers = 0.0, 0.0, 0
        self.model.eval()

        try:
            with torch.no_grad():
                for step, batch in enumerate(self.val_loader):
                    # The last batch field is the image path list; the rest
                    # are moved to the device.
                    img_path = batch[-1]
                    batch = [x.to(device) for x in batch[:-1]]

                    (image, question_len, source, target, caption, q_idx_vec,
                     pos, att, context, refs,
                     answer) = self.compute_cap_features_val(batch)

                    # Teacher-forced loss and token accuracy.
                    logits = self.model(image, caption, pos, context, att,
                                        source, q_idx_vec)
                    batch_loss = masked_CE(logits, target, question_len)
                    loss += batch_loss.item()
                    predictions = seq_max_and_mask(
                        logits, question_len,
                        self.val_loader.dataset.max_q_len + 1)
                    # Each question contributes (matching positions) / its
                    # question_len.
                    # NOTE(review): if predictions and target share a padding
                    # value beyond question_len, padded positions also count
                    # as matches — confirm.
                    correct += torch.sum(
                        (target == predictions).float() /
                        question_len.float().unsqueeze(1)).item()

                    # evaluate using VQA expert
                    result = self.model.sample(
                        image,
                        caption,
                        pos,
                        context,
                        att,
                        q_idx_vec,
                        greedy=True,
                        max_seq_len=self.val_loader.dataset.max_q_len + 1)
                    sample, mask = result.question, result.mask

                    correct_answers += self.query_vqa(image, sample, refs,
                                                      answer, mask)

                    # Only every visualize_every-th batch is written to the
                    # html file to visualize training progress.
                    if step % self.opt.visualize_every != 0:
                        continue

                    beam_size = 3
                    a_i2w = self.val_loader.dataset.a_i2w
                    q_i2w = self.val_loader.dataset.q_i2w

                    sample_len = mask.sum(dim=1)

                    # NOTE(review): beam search is capped at a fixed
                    # max_seq_len=15 while greedy sampling above uses
                    # max_q_len + 1 — confirm this difference is intended.
                    _, _, beam_predictions, beam_lps = self.model.sample_beam(
                        image,
                        caption,
                        pos,
                        context,
                        att,
                        q_idx_vec,
                        beam_size=beam_size,
                        max_seq_len=15)
                    beam_predictions, lps, lens = [
                        util.to_np(x) for x in [
                            beam_predictions,
                            torch.sum(beam_lps, dim=2),
                            # length = number of steps with non-zero log-prob
                            torch.sum(beam_lps != 0, dim=2)
                        ]
                    ]

                    target_np, question_len_np, sample_np, sample_len_np = [
                        util.to_np(x)
                        for x in [target, question_len, sample, sample_len]
                    ]

                    for i in range(image.size(0)):
                        entry = {
                            'img': img_path[i],
                            'epoch': self.epoch,
                            'answer': a_i2w[answer[i].item()],
                            # question_len - 1: the visible prefix excludes
                            # the final target position.
                            'gt_question': ' '.join(
                                util.idx2str(
                                    q_i2w,
                                    target_np[i][:question_len_np[i] - 1])),
                            'greedy_question': ' '.join(
                                util.idx2str(q_i2w,
                                             sample_np[i][:sample_len_np[i]])),
                            'beamsearch': [
                                ' '.join(
                                    util.idx2str(
                                        q_i2w,
                                        beam_predictions[i, j][:lens[i, j]]))
                                + " logprob: {}".format(lps[i, j])
                                for j in range(beam_size)
                            ],
                        }
                        self.visualizer.add_entry(entry)

            self.visualizer.update_html()
        finally:
            # Reset to training mode no matter how the pass ended.
            self.model.train()

        n_samples = len(self.val_loader.dataset)
        return [
            loss / n_samples, correct / n_samples,
            100.0 * correct_answers / n_samples
        ]