def eval_coco(model, loader, run_name, log_dir, config): print("Computing coco-caption scores") # run coco-caption eval model.eval() with torch.no_grad(): # compute predictions for each image predicted_captions = [] for i, (image, image_id) in enumerate(loader): if config.beam_size < 2: # sample result = model.sample(image.to(device), greedy=config.eval_greedy, max_seq_len=loader.dataset.max_c_len + 1) caps, cap_lens = to_np(result.caption), to_np( torch.sum(result.mask, dim=1)) else: # beam search result = model.sample_beam( image.to(device), beam_size=config.beam_size, max_seq_len=loader.dataset.max_c_len + 1) # subtract 1 to exclude <eos> token which is returned by beam search caps, cap_lens = to_np(result.best_caption), to_np( torch.sum(result.best_logprobs != 0, dim=1) - 1) # append captions to list for j in range(image.size(0)): caption = ' '.join([ loader.dataset.c_i2w[x] for x in caps[j][:cap_lens[j]] ]).strip() pred = {'image_id': image_id[j].item(), 'caption': caption} predicted_captions.append(pred) # Reset model.train() return language_scores(predicted_captions, run_name, log_dir, annFile=config.coco_annotation_file)
def build_cap_ref_dicts(pred, pred_len, targ, targ_len, eos_symbol, c_i2w):
    """Build the (ground-truths, results) structures expected by the
    caption-scoring code.

    Args:
        pred: predicted caption token ids, one caption per batch item.
        pred_len: length of each predicted caption.
        targ: reference caption token ids, several references per batch item.
        targ_len: length of each reference caption.
        eos_symbol: end-of-sequence token id, passed to `symbol_to_string`.
        c_i2w: index -> word vocabulary mapping.

    Returns:
        gts: dict mapping batch index -> list of reference strings.
        res: list of {'image_id': batch_index, 'caption': [prediction string]}.
            Note 'image_id' is the position in the batch, not a dataset id.
    """
    pred, pred_len, ref, ref_len = [
        to_np(x).tolist() for x in [pred, pred_len, targ, targ_len]
    ]
    # Fix: the original built an OrderedDict and then immediately rebuilt it
    # into a plain dict — the rebuild was dead work. A single dict
    # comprehension produces the identical plain-dict result directly.
    gts = {
        i: [
            symbol_to_string(r, l, eos_symbol, c_i2w)
            for r, l in zip(ref[i], ref_len[i])
        ]
        for i in range(len(ref))
    }
    res = [{
        'image_id': i,
        'caption': [symbol_to_string(pred[i], pred_len[i], eos_symbol, c_i2w)]
    } for i in range(len(pred))]
    return gts, res
def collect_batch(self, index, captions, masks, rewards, ans=None, qs=None, qlens=None, topk=None):
    """Convert one batch of tensors to numpy and append it to self.data_dump.

    Caption lengths are derived from the masks (sum over the time axis).
    Question-related fields are only kept when this collector serves the
    QuestionAskingTrainer; otherwise they are stored as empty lists.
    """
    index = to_np(index)
    captions = [to_np(c) for c in captions]
    lens = [to_np(torch.sum(m, dim=1)) for m in masks]
    if self.model == "QuestionAskingTrainer":
        ans = [to_np(a) for a in ans]
        qs = [to_np(q) for q in qs]
        qlens = [to_np(ql) for ql in qlens]
        topk = [to_np(t) for t in topk]
    else:
        ans, qs, qlens, topk = [], [], [], []
    self.data_dump.append(
        [index, captions, lens, rewards, ans, qs, qlens, topk])
def evaluate_with_questions(self):
    """Validate the decision maker: caption each validation image, pick a
    time step to ask the teacher a question, rebuild the caption from the
    answer (rollout vs. single-word replace), and score both the base and
    the improved captions with the COCO metrics.

    Returns:
        acc: per-caption-variant word accuracy (list of 2 floats).
        eval_scores: list of `language_scores` results, one per variant.
    """
    self.std_logger.info("Validating decision maker")
    self.dmaker.eval()
    # captioner/qgen stay in train mode; the commented block below shows the
    # alternative eval-mode configuration that was considered.
    self.captioner.train()
    self.qgen.train()
    # if self.opt.cap_eval:
    #     self.captioner.eval()
    # else:
    #     self.captioner.train()
    #
    # if self.opt.quegen_eval:
    #     self.question_generator.eval()
    # else:
    #     self.question_generator.train()
    c_i2w = self.val_loader.dataset.c_i2w
    # NOTE(review): pos_correct is initialised but never updated or returned.
    correct, pos_correct, eval_scores = [0.0, 0.0], [0.0, 0.0], []
    caption_predictions = [[], []]  # [base captions, improved captions]
    with torch.no_grad():
        for step, sample in enumerate(self.val_loader):
            image, refs, target, caption_len = [
                x.to(device)
                for x in [sample[0], sample[4], sample[2], sample[3]]
            ]
            ref_lens, img_path, index, img_id = sample[5], sample[
                7], sample[8], sample[9]
            batch_size = image.size(0)
            caps, cmasks, poses, pps, qs, qlps, qmasks, aps, cps, atts = [], [], [], [], [], [], [], [], [], []
            # 1. Caption completely (greedy, fixed seed for reproducibility)
            self.set_seed(self.opt.seed)
            r = self.captioner.sample(
                image,
                greedy=True,
                max_seq_len=self.opt.c_max_sentence_len + 1)
            caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
                = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg
            cap_len = cap_mask.long().sum(dim=1)
            caps.append(caption)
            cmasks.append(cap_mask)
            poses.append(pos_probs)
            caption = self.pad_caption(caption, cap_len)
            # get the hidden state context by re-encoding the padded caption
            source = torch.cat([
                self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]
            ], dim=1)
            r = self.fixed_caption_encoder(image, source, gt_pos=None, ss=False)
            h = r.hidden
            # 2. Identify the best time to ask a question, excluding ended sentences
            logit, valid_pos_mask = self.dmaker(
                h, attended_img, caption, cap_len, pos_probs, topk_words,
                self.captioner.caption_embedding.weight.data)
            masked_prob = masked_softmax(
                logit, cap_mask, valid_pos_mask,
                max_len=self.opt.c_max_sentence_len)
            dm_prob, ask_idx, ask_flag, ask_mask = self.sample_decision(
                masked_prob, cap_mask, greedy=True)
            # 3. Ask the teacher a question and get the answer
            ans, ans_mask, r = self.ask_question(image, caption, refs,
                                                 pos_probs, h, att, ask_idx,
                                                 q_greedy=True)
            # logging
            # NOTE(review): every append below is duplicated. In do_iteration
            # the second append holds the greedy-decision variant; here there
            # is no second variant and only index [0] of each list is read
            # later, so the duplicates are dead — confirm they are intentional.
            cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
            cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
            pps.append(r.pos_prob[0])
            pps.append(r.pos_prob[0])
            atts.append(r.att[0])
            atts.append(r.att[0])
            qlps.append(r.q_logprob.unsqueeze(1))
            qlps.append(r.q_logprob.unsqueeze(1))
            qmasks.append(r.q_mask)
            qmasks.append(r.q_mask)
            qs.append(r.question)
            qs.append(r.question)
            aps.append(r.ans_prob)
            aps.append(r.ans_prob)
            # 4. Compute new captions based on teacher's answer
            # rollout caption: re-decode from the ask position onward
            r = self.caption_with_teacher_answer(image, ask_mask, ans_mask,
                                                 greedy=True)
            poses.append(r.pos_prob)
            rollout = r.caption
            rollout_mask = r.cap_mask
            # replace caption: substitute the answer word in the base caption
            replace = replace_word_in_caption(caps[0], ans, ask_idx, ask_flag)
            base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0], dim=1),
                                    refs, ref_lens, self.scorers, self.c_i2w)
            rollout_rwd = mixed_reward(rollout, torch.sum(rollout_mask, dim=1),
                                       refs, ref_lens, self.scorers, self.c_i2w)
            replace_rwd = mixed_reward(replace, torch.sum(cmasks[0], dim=1),
                                       refs, ref_lens, self.scorers, self.c_i2w)
            caps.append(rollout)
            caps.append(replace)
            cmasks.append(rollout_mask)
            cmasks.append(cmasks[0])  # replace keeps the base caption's mask
            stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
                replace_rwd, rollout_rwd, base_rwd)
            # keep only the base caption and the better of replace/rollout
            best_cap, best_cap_mask = choose_better_caption(
                replace_rwd, replace, cmasks[0], rollout_rwd, rollout,
                rollout_mask)
            caps = [caps[0], best_cap]
            cmasks = [cmasks[0], best_cap_mask]
            # Collect captions for coco evaluation
            img_id = util.to_np(img_id)
            for i in range(len(caps)):
                words, lens, = util.to_np(caps[i]), util.to_np(
                    cmasks[i].sum(dim=1))
                for j in range(image.size(0)):
                    inds = words[j][:lens[j]]
                    caption = ""
                    for k, ind in enumerate(inds):
                        if k > 0:
                            caption = caption + ' '
                        caption = caption + c_i2w[ind]
                    pred = {'image_id': img_id[j], 'caption': caption}
                    caption_predictions[i].append(pred)
            # per-variant word accuracy against the ground-truth caption
            for i in range(len(correct)):
                predictions = caps[i] * cmasks[i].long()
                correct[i] += (
                    (target == predictions).float() /
                    caption_len.float().unsqueeze(1)).sum().item()
            # Logging: dump a human-readable sample of this batch to HTML
            if step % self.opt.val_print_every == 0:
                c_i2w = self.val_loader.dataset.c_i2w
                p_i2w = self.val_loader.dataset.p_i2w
                q_i2w = self.val_loader.dataset.q_i2w
                a_i2w = self.val_loader.dataset.a_i2w
                caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len, \
                    flag, raw_idx, refs, ref_lens = \
                    [util.to_np(x) for x in
                     [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
                      rollout_mask.long().sum(dim=1), masked_prob, qs[0], qlps[0],
                      qmasks[0].long().sum(dim=1), ask_flag.long(), ask_idx, refs, ref_lens]]
                pos_probs, ans_probs, cap_probs = pps[0], aps[0], cps[0]
                for i in range(image.size(0)):
                    top_pos = torch.topk(pos_probs, 3)[1]
                    top_ans = torch.topk(ans_probs[i], 3)[1]
                    top_cap = torch.topk(cap_probs[i], 5)[1]
                    # word the question was asked about ('Nan' if out of range)
                    word = c_i2w[caption[i][
                        raw_idx[i]]] if raw_idx[i] < len(
                            caption[i]) else 'Nan'
                    entry = {
                        'img': img_path[i],
                        'epoch': self.collection_epoch,
                        'caption':
                        ' '.join(
                            util.idx2str(c_i2w, (caption[i])[:cap_len[i]])) +
                        " | Reward: {:.2f}".format(base_rwd[i]),
                        'qaskprobs':
                        ' '.join([
                            "{:.2f}".format(x) if x > 0.0 else "_"
                            for x in dec_probs[i]
                        ]),
                        'rollout_caption':
                        ' '.join(
                            util.idx2str(c_i2w,
                                         (rollout[i])[:rollout_len[i]])) +
                        " | Reward: {:.2f}".format(rollout_rwd[i]),
                        'replace_caption':
                        ' '.join(
                            util.idx2str(c_i2w, (replace[i])[:cap_len[i]])) +
                        " | Reward: {:.2f}".format(replace_rwd[i]),
                        'index': raw_idx[i],
                        'flag': bool(flag[i]),
                        'word': word,
                        'pos':
                        ' | '.join([
                            p_i2w[x.item()] if x.item() in p_i2w else p_i2w[18]
                            for x in top_pos
                        ]),
                        'question':
                        ' '.join(
                            util.idx2str(q_i2w, (question[i])[:q_len[i]])) +
                        " | logprob: {}".format(
                            q_logprob[i, :q_len[i]].sum()),
                        'answers':
                        ' | '.join([a_i2w[x.item()] for x in top_ans]),
                        'words':
                        ' | '.join([c_i2w[x.item()] for x in top_cap]),
                        'refs': [
                            ' '.join(
                                util.idx2str(c_i2w, (refs[i, j])[:ref_lens[i, j]]))
                            for j in range(3)
                        ]
                    }
                    self.valLLvisualizer.add_entry(entry)
                info = {
                    'eval/replace over rollout':
                    float(stat_rero) / (batch_size),
                    'eval/rollout over replace':
                    float(stat_rore) / (batch_size),
                    'eval/replace over all':
                    float(stat_reall) / (batch_size),
                    'eval/rollout over all':
                    float(stat_roall) / (batch_size),
                    'eval/question asking frequency (percent)':
                    float(ask_flag.sum().item()) / ask_flag.size(0)
                }
                util.step_logging(self.logger, info, self.eval_steps)
                self.eval_steps += 1
    self.valLLvisualizer.update_html()
    acc = [x / len(self.val_loader.dataset) for x in correct]
    # COCO-score each caption variant separately
    for i in range(len(correct)):
        eval_scores.append(
            language_scores(caption_predictions[i], self.opt.run_name,
                            self.result_path,
                            annFile=self.opt.coco_annotation_file))
    self.dmaker.train()
    return acc, eval_scores
def do_iteration(self, image, refs, ref_lens, index, img_path):
    """One policy-gradient training step for the decision maker.

    Samples a question-asking decision, baselines it against the greedy
    decision (self-critical), rebuilds captions from the teacher's answer,
    rewards both, and updates the decision maker with the reward difference.

    Args:
        image: batch of images.
        refs: reference captions for reward computation.
        ref_lens: lengths of the reference captions.
        index: dataset indices of the batch items (saved for data collection).
        img_path: image file paths (used only for HTML visualization).

    Returns:
        Scalar policy-gradient loss value for this batch.
    """
    self.d_optimizer.zero_grad()
    batch_size = image.size(0)
    caps, cmasks, qs, qlps, qmasks, aps, pps, atts, cps, dps = [], [], [], [], [], [], [], [], [], []
    # 1. Caption completely (fixed seed so sampling is reproducible)
    self.set_seed(self.opt.seed)
    r = self.captioner.sample(image,
                              greedy=self.opt.cap_greedy,
                              max_seq_len=self.opt.c_max_sentence_len + 1,
                              temperature=self.opt.temperature)
    caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
        = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg
    # Don't backprop through captioner
    caption, cap_probs, cap_mask, pos_probs, attended_img \
        = [x.detach() for x in [caption, cap_probs, cap_mask, pos_probs, attended_img]]
    cap_len = cap_mask.long().sum(dim=1)
    caps.append(caption)
    cmasks.append(cap_mask)
    caption = self.pad_caption(caption, cap_len)
    # get the hidden state context by re-encoding the padded caption
    source = torch.cat(
        [self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]],
        dim=1)
    r = self.fixed_caption_encoder(image, source, gt_pos=None, ss=False)
    h = r.hidden.detach()
    topk_words = [[y.detach() for y in x] for x in topk_words]
    # 2. Identify the best time to ask a question, excluding ended sentences,
    #    baseline against the greedy decision (self-critical baseline)
    logit, valid_pos_mask = self.dmaker(
        h, attended_img, caption, cap_len, pos_probs, topk_words,
        self.captioner.caption_embedding.weight.data)
    masked_prob = masked_softmax(logit,
                                 cap_mask,
                                 valid_pos_mask,
                                 self.opt.dm_temperature,
                                 max_len=self.opt.c_max_sentence_len)
    dm_prob, ask_idx, ask_flag, ask_mask = self.sample_decision(
        masked_prob, cap_mask, greedy=False)
    _, ask_idx_greedy, ask_flag_greedy, ask_mask_greedy = self.sample_decision(
        masked_prob, cap_mask, greedy=True)
    dps.append(dm_prob.unsqueeze(1))
    # 3. Ask the teacher a question and get the answer (sampled and greedy
    #    ask positions each get their own question/answer)
    ans, ans_mask, r = self.ask_question(image, caption, refs, pos_probs,
                                         h, att, ask_idx,
                                         q_greedy=self.opt.q_greedy,
                                         temperature=self.opt.temperature)
    ans_greedy, ans_mask_greedy, rg = self.ask_question(
        image, caption, refs, pos_probs, h, att, ask_idx_greedy,
        q_greedy=self.opt.q_greedy, temperature=self.opt.temperature)
    # logging stuff: index 0 = sampled decision, index 1 = greedy decision
    cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
    cps.append(cap_probs[self.range_vector[:batch_size], ask_idx_greedy])
    pps.append(r.pos_prob[0])
    pps.append(rg.pos_prob[0])
    atts.append(r.att[0])
    atts.append(rg.att[0])
    qlps.append(r.q_logprob.unsqueeze(1))
    qlps.append(rg.q_logprob.unsqueeze(1))
    qmasks.append(r.q_mask)
    qmasks.append(rg.q_mask)
    qs.append(r.question.detach())
    qs.append(rg.question.detach())
    aps.append(r.ans_prob.detach())
    aps.append(rg.ans_prob.detach())
    # 4. Compute new captions based on teacher's answer
    # rollout caption: re-decode from the ask position onward
    r = self.caption_with_teacher_answer(image, ask_mask, ans_mask,
                                         greedy=self.opt.cap_greedy,
                                         temperature=self.opt.temperature)
    rg = self.caption_with_teacher_answer(image, ask_mask_greedy,
                                          ans_mask_greedy,
                                          greedy=self.opt.cap_greedy,
                                          temperature=self.opt.temperature)
    rollout, rollout_mask, rollout_greedy, rollout_mask_greedy = [
        x.detach() for x in [r.caption, r.cap_mask, rg.caption, rg.cap_mask]
    ]
    # replace caption: substitute the answer word in the base caption
    replace = replace_word_in_caption(caps[0], ans, ask_idx, ask_flag)
    replace_greedy = replace_word_in_caption(caps[0], ans_greedy,
                                             ask_idx_greedy,
                                             ask_flag_greedy)
    # 5. Compute reward for captions
    base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0], dim=1), refs,
                            ref_lens, self.scorers, self.c_i2w)
    rollout_rwd = mixed_reward(rollout, torch.sum(rollout_mask, dim=1),
                               refs, ref_lens, self.scorers, self.c_i2w)
    rollout_greedy_rwd = mixed_reward(
        rollout_greedy, torch.sum(rollout_mask_greedy, dim=1), refs,
        ref_lens, self.scorers, self.c_i2w)
    replace_rwd = mixed_reward(replace, torch.sum(cmasks[0], dim=1), refs,
                               ref_lens, self.scorers, self.c_i2w)
    replace_greedy_rwd = mixed_reward(replace_greedy,
                                      torch.sum(cmasks[0], dim=1), refs,
                                      ref_lens, self.scorers, self.c_i2w)
    # reward = best of replace/rollout, for sampled and greedy decisions
    rwd = np.maximum(replace_rwd, rollout_rwd)
    rwd_greedy = np.maximum(replace_greedy_rwd, rollout_greedy_rwd)
    best_cap, best_cap_mask = choose_better_caption(
        replace_rwd, replace, cmasks[0], rollout_rwd, rollout, rollout_mask)
    best_cap_greedy, best_cap_greedy_mask = choose_better_caption(
        replace_greedy_rwd, replace_greedy, cmasks[0], rollout_greedy_rwd,
        rollout_greedy, rollout_mask_greedy)
    # some statistics on whether rollout or single-word-replace is better
    stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
        replace_rwd, rollout_rwd, base_rwd)
    caps.append(best_cap)
    cmasks.append(best_cap_mask)
    caps.append(best_cap_greedy)
    cmasks.append(best_cap_greedy_mask)
    # Backwards pass to train decision maker with policy gradient loss
    # (self-critical advantage = sampled reward - greedy reward - ask penalty)
    reward_delta = torch.from_numpy(rwd - rwd_greedy).type(
        torch.float).to(device)
    reward_delta = reward_delta - self.opt.ask_penalty * ask_flag.float()
    loss = masked_PG(reward_delta.detach(),
                     torch.log(dps[0]).squeeze(), ask_flag.detach())
    loss.backward()
    self.d_optimizer.step()
    # also save the question asked and answer, and top-k predictions from captioner
    answers = [torch.max(x, dim=1)[1] for x in aps]
    topwords = [torch.topk(x, 20)[1] for x in cps]
    question_lens = [x.sum(dim=1) for x in qmasks]
    self.data_collector.collect_batch(index.clone(), caps, cmasks,
                                      [base_rwd, rwd, rwd_greedy], answers,
                                      qs, question_lens, topwords)
    # Logging: dump a human-readable sample of this batch to HTML
    if self.collection_steps % self.opt.print_every == 0:
        c_i2w = self.val_loader.dataset.c_i2w
        p_i2w = self.val_loader.dataset.p_i2w
        q_i2w = self.val_loader.dataset.q_i2w
        a_i2w = self.val_loader.dataset.a_i2w
        caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len,\
            flag, raw_idx, refs, ref_lens = \
            [util.to_np(x) for x in
             [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
              rollout_mask.long().sum(dim=1), masked_prob, qs[0], qlps[0],
              qmasks[0].long().sum(dim=1), ask_flag.long(), ask_idx, refs, ref_lens]]
        pos_probs, ans_probs, cap_probs = pps[0], aps[0], cps[0]
        for i in range(image.size(0)):
            top_pos = torch.topk(pos_probs, 3)[1]
            top_ans = torch.topk(ans_probs[i], 3)[1]
            top_cap = torch.topk(cap_probs[i], 5)[1]
            # word the question was asked about ('Nan' if out of range)
            word = c_i2w[caption[i][raw_idx[i]]] if raw_idx[i] < len(
                caption[i]) else 'Nan'
            entry = {
                'img': img_path[i],
                'epoch': self.collection_epoch,
                'caption':
                ' '.join(util.idx2str(c_i2w, (caption[i])[:cap_len[i]])) +
                " | Reward: {:.2f}".format(base_rwd[i]),
                'qaskprobs':
                ' '.join([
                    "{:.2f}".format(x) if x > 0.0 else "_"
                    for x in dec_probs[i]
                ]),
                'rollout_caption':
                ' '.join(util.idx2str(c_i2w,
                                      (rollout[i])[:rollout_len[i]])) +
                " | Reward: {:.2f}".format(rollout_rwd[i]),
                'replace_caption':
                ' '.join(util.idx2str(c_i2w, (replace[i])[:cap_len[i]])) +
                " | Reward: {:.2f}".format(replace_rwd[i]),
                'index': raw_idx[i],
                'flag': bool(flag[i]),
                'word': word,
                'pos':
                ' | '.join([
                    p_i2w[x.item()] if x.item() in p_i2w else p_i2w[18]
                    for x in top_pos
                ]),
                'question':
                ' '.join(util.idx2str(q_i2w, (question[i])[:q_len[i]])) +
                " | logprob: {}".format(q_logprob[i, :q_len[i]].sum()),
                'answers':
                ' | '.join([a_i2w[x.item()] for x in top_ans]),
                'words':
                ' | '.join([c_i2w[x.item()] for x in top_cap]),
                'refs': [
                    ' '.join(
                        util.idx2str(c_i2w, (refs[i, j])[:ref_lens[i, j]]))
                    for j in range(3)
                ]
            }
            self.trainLLvisualizer.add_entry(entry)
        info = {
            'collect/question logprob':
            torch.mean(torch.cat(qlps, dim=1)).item(),
            'collect/replace over rollout':
            float(stat_rero) / (batch_size),
            'collect/rollout over replace':
            float(stat_rore) / (batch_size),
            'collect/replace over all':
            float(stat_reall) / (batch_size),
            'collect/rollout over all':
            float(stat_roall) / (batch_size),
            'collect/sampled decision equals greedy decision':
            float((ask_idx == ask_idx_greedy).sum().item()) / (batch_size),
            'collect/question asking frequency (percent)':
            float(ask_flag.sum().item()) / ask_flag.size(0)
        }
        util.step_logging(self.logger, info, self.collection_steps)
    return loss.item()
def validate(self):
    """Validate the VQA answering model.

    Computes validation loss and top-1 answer accuracy, collects predictions
    in the VQA2.0 results format, and periodically writes Q&A samples to the
    HTML visualizer.

    Returns:
        samples: list of {'question_id': int, 'answer': str} predictions.
        [avg_loss, accuracy]: dataset-averaged loss and accuracy.
    """
    print("Validating")
    self.model.eval()
    loss, correct = 0.0, 0.0
    samples = []
    # no_grad added for consistency with the other validators (no backward
    # pass happens here, so gradients were computed needlessly).
    with torch.no_grad():
        for step, sample in enumerate(self.val_loader):
            image, question, question_len, answer, captions, cap_lens, img_path, que_id = sample
            image, question, answer, captions = [
                x.to(device) for x in [image, question, answer, captions]
            ]
            # Forward pass
            result = self.model(image, question, captions)
            probs, logits = result.probs, result.logits
            loss += self.loss_function(logits, answer).item()
            # Compute top answer accuracy
            _, prediction_max_index = torch.max(probs, 1)
            _, answer_max_index = torch.max(answer, 1)
            correct += (
                answer_max_index == prediction_max_index).float().sum().item()
            a_i2w = self.val_loader.dataset.a_i2w
            q_i2w = self.val_loader.dataset.q_i2w
            c_i2w = self.val_loader.dataset.c_i2w
            # Append prediction to VQA2.0 validation
            for i in range(image.size(0)):
                samples.append({
                    'question_id': que_id[i].item(),
                    'answer': a_i2w[prediction_max_index[i].item()]
                })
            # write image and Q&A to html file to visualize training progress
            if step % self.opt.visualize_every == 0:
                # Show the top 3 predicted answers
                pros, ans = torch.topk(probs, k=3, dim=1)
                captions, cap_lens, question, question_len, pros, ans =\
                    [util.to_np(x) for x in [captions, cap_lens, question, question_len, pros, ans]]
                # Fix: was `range(image.size(0) / 2)` — true division returns
                # a float in Python 3 and range() raises TypeError; use floor
                # division to visualize half the batch as intended.
                for i in range(image.size(0) // 2):
                    # Show question
                    que_arr = util.idx2str(q_i2w,
                                           (question[i])[:question_len[i]])
                    entry = {
                        'img': img_path[i],
                        'epoch': self.epoch,
                        'question': ' '.join(que_arr),
                        'gt_ans': a_i2w[answer_max_index[i].item()],
                        'predictions': [[p, a_i2w[a]]
                                        for p, a in zip(pros[i], ans[i])],
                        'refs': [
                            ' '.join(
                                util.idx2str(
                                    c_i2w,
                                    (captions[i, j])[:cap_lens[i, j]]))
                            for j in range(3)
                        ]
                    }
                    self.visualizer.add_entry(entry)
                self.visualizer.update_html()
    # Reset
    self.model.train()
    return samples, [
        x / len(self.val_loader.dataset) for x in [loss, correct]
    ]
def validate_captioner(self):
    """Validate the captioner: accumulate word/POS loss and accuracy via
    `validate_helper`, and periodically write beam-search and greedy caption
    samples (with POS predictions) to the HTML visualizer.

    Returns:
        [word_loss, word_cor, pos_cor, pos_loss] averaged over the dataset.
    """
    self.captioner.eval()
    word_loss, word_cor, pos_cor, pos_loss = 0.0, 0.0, 0.0, 0.0
    with torch.no_grad():
        for step, sample in enumerate(self.val_loader):
            # last two fields of the sample are unused here
            image, source, target, caption_len, refs, ref_lens, pos, img_path = sample[:-2]
            sample = [
                x.to(device)
                for x in [image, source, target, caption_len, pos]
            ]
            r = validate_helper(sample, self.captioner,
                                self.val_loader.dataset.max_c_len)
            # accumulate the four running totals element-wise
            word_loss, word_cor, pos_cor, pos_loss = [
                x + y for x, y in zip(
                    [word_loss, word_cor, pos_cor, pos_loss], r)
            ]
            # write image and predicted captions to html file to visualize training progress
            if step % self.opt.val_print_every == 0:
                beam_size = 3
                c_i2w = self.val_loader.dataset.c_i2w
                p_i2w = self.val_loader.dataset.p_i2w
                refs, ref_lens = [util.to_np(x) for x in [refs, ref_lens]]
                # show top 3 beam search captions; length = count of
                # positions with non-negligible logprob
                result = self.captioner.sample_beam(sample[0],
                                                    beam_size=beam_size,
                                                    max_seq_len=17)
                captions, lps, lens = [
                    util.to_np(x) for x in [
                        result.captions,
                        torch.sum(result.logprobs, dim=2),
                        torch.sum(result.logprobs.abs() > 0.00001, dim=2)
                    ]
                ]
                # show greedy caption with its POS-tag predictions
                result = self.captioner.sample(sample[0],
                                               greedy=True,
                                               max_seq_len=17)
                pprob, pidx = torch.max(result.pos_prob, dim=2)
                gcap, glp, glens, pidx, pprob = [
                    util.to_np(x) for x in [
                        result.caption,
                        torch.sum(result.log_prob, dim=1),
                        torch.sum(result.mask, dim=1), pidx, pprob
                    ]
                ]
                for i in range(image.size(0)):
                    cap_arr = util.idx2str(c_i2w, (gcap[i])[:glens[i]])
                    pos_arr = util.idx2pos(p_i2w, (pidx[i])[:glens[i]])
                    # "word (POS prob)" for each generated token
                    pos_pred = [
                        "{} ({} {:.2f})".format(cap_arr[j], pos_arr[j],
                                                pprob[i, j])
                        for j in range(glens[i])
                    ]
                    entry = {
                        'img': img_path[i],
                        'epoch': self.cap_epoch,
                        'greedy_sample':
                        ' '.join(cap_arr) +
                        " logprob: {}".format(glp[i]),
                        'pos_pred': ' '.join(pos_pred),
                        'beamsearch': [
                            ' '.join(
                                util.idx2str(
                                    c_i2w, (captions[i, j])[:lens[i, j]])) +
                            " logprob: {}".format(lps[i, j])
                            for j in range(beam_size)
                        ],
                        'refs': [
                            ' '.join(
                                util.idx2str(
                                    c_i2w, (refs[i, j])[:ref_lens[i, j]]))
                            for j in range(3)
                        ]
                    }
                    self.Cvisualizer.add_entry(entry)
                self.Cvisualizer.update_html()
    # Reset
    self.captioner.train()
    return [
        x / len(self.val_loader.dataset)
        for x in [word_loss, word_cor, pos_cor, pos_loss]
    ]
def validate(self):
    """Validate the question-generation model.

    Computes masked cross-entropy loss and per-token accuracy on the
    ground-truth questions, queries the VQA expert with greedily sampled
    questions to count correctly answered ones, and periodically writes
    samples (ground truth / greedy / beam search) to the HTML visualizer.

    Returns:
        [avg_loss, avg_accuracy, percent_correct_answers].
    """
    loss, correct, correct_answers = 0.0, 0.0, 0
    self.model.eval()
    for step, batch in enumerate(self.val_loader):
        img_path = batch[-1]
        batch = [x.to(device) for x in batch[:-1]]
        image, question_len, source, target, caption, q_idx_vec, pos, att, context, refs, answer = self.compute_cap_features_val(
            batch)
        logits = self.model(image, caption, pos, context, att, source,
                            q_idx_vec)
        batch_loss = masked_CE(logits, target, question_len)
        loss += batch_loss.item()
        # per-token accuracy, normalized by question length
        predictions = seq_max_and_mask(
            logits, question_len, self.val_loader.dataset.max_q_len + 1)
        correct += torch.sum((target == predictions).float() /
                             question_len.float().unsqueeze(1)).item()
        # evaluate using VQA expert: does the expert answer the generated
        # question correctly?
        result = self.model.sample(
            image, caption, pos, context, att, q_idx_vec,
            greedy=True,
            max_seq_len=self.val_loader.dataset.max_q_len + 1)
        sample, log_probs, mask = result.question, result.log_prob, result.mask
        correct_answers += self.query_vqa(image, sample, refs, answer, mask)
        # write image and Q&A to html file to visualize training progress
        if step % self.opt.visualize_every == 0:
            beam_size = 3
            c_i2w = self.val_loader.dataset.c_i2w
            a_i2w = self.val_loader.dataset.a_i2w
            q_i2w = self.val_loader.dataset.q_i2w
            sample_len = mask.sum(dim=1)
            _, _, beam_predictions, beam_lps = self.model.sample_beam(
                image, caption, pos, context, att, q_idx_vec,
                beam_size=beam_size, max_seq_len=15)
            # beam length = count of positions with non-zero logprob
            beam_predictions, lps, lens = [
                util.to_np(x) for x in [
                    beam_predictions,
                    torch.sum(beam_lps, dim=2),
                    torch.sum(beam_lps != 0, dim=2)
                ]
            ]
            target, question_len, sample, sample_len = [
                util.to_np(x)
                for x in [target, question_len, sample, sample_len]
            ]
            for i in range(image.size(0)):
                entry = {
                    'img': img_path[i],
                    'epoch': self.epoch,
                    'answer': a_i2w[answer[i].item()],
                    # drop the final token (presumably <eos>) from the
                    # ground-truth question — TODO confirm
                    'gt_question':
                    ' '.join(
                        util.idx2str(q_i2w,
                                     (target[i])[:question_len[i] - 1])),
                    'greedy_question':
                    ' '.join(
                        util.idx2str(q_i2w, (sample[i])[:sample_len[i]])),
                    'beamsearch': [
                        ' '.join(
                            util.idx2str(
                                q_i2w,
                                (beam_predictions[i, j])[:lens[i, j]])) +
                        " logprob: {}".format(lps[i, j])
                        for j in range(beam_size)
                    ]
                }
                self.visualizer.add_entry(entry)
            self.visualizer.update_html()
    # Reset
    self.model.train()
    l = len(self.val_loader.dataset)
    return [loss / l, correct / l, 100.0 * correct_answers / l]