def do_iteration(self, image, refs, ref_lens, index, img_path):
    self.d_optimizer.zero_grad()
    batch_size = image.size(0)
    caps, cmasks, qs, qlps, qmasks, aps, pps, atts, cps, dps = \
        [], [], [], [], [], [], [], [], [], []

    # 1. Caption completely
    self.set_seed(self.opt.seed)
    r = self.captioner.sample(image, greedy=self.opt.cap_greedy,
                              max_seq_len=self.opt.c_max_sentence_len + 1,
                              temperature=self.opt.temperature)
    caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
        = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg

    # Don't backprop through captioner
    caption, cap_probs, cap_mask, pos_probs, attended_img \
        = [x.detach() for x in [caption, cap_probs, cap_mask, pos_probs, attended_img]]
    cap_len = cap_mask.long().sum(dim=1)
    caps.append(caption)
    cmasks.append(cap_mask)
    caption = self.pad_caption(caption, cap_len)

    # get the hidden state context
    source = torch.cat(
        [self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]], dim=1)
    r = self.fixed_caption_encoder(image, source, gt_pos=None, ss=False)
    h = r.hidden.detach()
    topk_words = [[y.detach() for y in x] for x in topk_words]

    # 2. Identify the best time to ask a question, excluding ended sentences;
    #    baseline against the greedy decision
    logit, valid_pos_mask = self.dmaker(
        h, attended_img, caption, cap_len, pos_probs, topk_words,
        self.captioner.caption_embedding.weight.data)
    masked_prob = masked_softmax(logit, cap_mask, valid_pos_mask,
                                 self.opt.dm_temperature,
                                 max_len=self.opt.c_max_sentence_len)
    dm_prob, ask_idx, ask_flag, ask_mask = self.sample_decision(
        masked_prob, cap_mask, greedy=False)
    _, ask_idx_greedy, ask_flag_greedy, ask_mask_greedy = self.sample_decision(
        masked_prob, cap_mask, greedy=True)
    dps.append(dm_prob.unsqueeze(1))

    # 3. Ask the teacher a question and get the answer
    ans, ans_mask, r = self.ask_question(image, caption, refs, pos_probs, h, att,
                                         ask_idx, q_greedy=self.opt.q_greedy,
                                         temperature=self.opt.temperature)
    ans_greedy, ans_mask_greedy, rg = self.ask_question(
        image, caption, refs, pos_probs, h, att, ask_idx_greedy,
        q_greedy=self.opt.q_greedy, temperature=self.opt.temperature)

    # logging stuff
    cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
    cps.append(cap_probs[self.range_vector[:batch_size], ask_idx_greedy])
    pps.append(r.pos_prob[0])
    pps.append(rg.pos_prob[0])
    atts.append(r.att[0])
    atts.append(rg.att[0])
    qlps.append(r.q_logprob.unsqueeze(1))
    qlps.append(rg.q_logprob.unsqueeze(1))
    qmasks.append(r.q_mask)
    qmasks.append(rg.q_mask)
    qs.append(r.question.detach())
    qs.append(rg.question.detach())
    aps.append(r.ans_prob.detach())
    aps.append(rg.ans_prob.detach())

    # 4. Compute new captions based on teacher's answer
    # rollout caption
    r = self.caption_with_teacher_answer(image, ask_mask, ans_mask,
                                         greedy=self.opt.cap_greedy,
                                         temperature=self.opt.temperature)
    rg = self.caption_with_teacher_answer(image, ask_mask_greedy, ans_mask_greedy,
                                          greedy=self.opt.cap_greedy,
                                          temperature=self.opt.temperature)
    rollout, rollout_mask, rollout_greedy, rollout_mask_greedy = [
        x.detach() for x in [r.caption, r.cap_mask, rg.caption, rg.cap_mask]
    ]
    # replace caption
    replace = replace_word_in_caption(caps[0], ans, ask_idx, ask_flag)
    replace_greedy = replace_word_in_caption(caps[0], ans_greedy,
                                             ask_idx_greedy, ask_flag_greedy)

    # 5. Compute reward for captions
    base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0], dim=1), refs,
                            ref_lens, self.scorers, self.c_i2w)
    rollout_rwd = mixed_reward(rollout, torch.sum(rollout_mask, dim=1), refs,
                               ref_lens, self.scorers, self.c_i2w)
    rollout_greedy_rwd = mixed_reward(
        rollout_greedy, torch.sum(rollout_mask_greedy, dim=1), refs, ref_lens,
        self.scorers, self.c_i2w)
    replace_rwd = mixed_reward(replace, torch.sum(cmasks[0], dim=1), refs,
                               ref_lens, self.scorers, self.c_i2w)
    replace_greedy_rwd = mixed_reward(replace_greedy, torch.sum(cmasks[0], dim=1),
                                      refs, ref_lens, self.scorers, self.c_i2w)

    rwd = np.maximum(replace_rwd, rollout_rwd)
    rwd_greedy = np.maximum(replace_greedy_rwd, rollout_greedy_rwd)
    best_cap, best_cap_mask = choose_better_caption(
        replace_rwd, replace, cmasks[0], rollout_rwd, rollout, rollout_mask)
    best_cap_greedy, best_cap_greedy_mask = choose_better_caption(
        replace_greedy_rwd, replace_greedy, cmasks[0], rollout_greedy_rwd,
        rollout_greedy, rollout_mask_greedy)

    # some statistics on whether rollout or single-word-replace is better
    stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
        replace_rwd, rollout_rwd, base_rwd)

    caps.append(best_cap)
    cmasks.append(best_cap_mask)
    caps.append(best_cap_greedy)
    cmasks.append(best_cap_greedy_mask)

    # Backward pass to train the decision maker with a policy-gradient loss
    reward_delta = torch.from_numpy(rwd - rwd_greedy).type(torch.float).to(device)
    reward_delta = reward_delta - self.opt.ask_penalty * ask_flag.float()
    loss = masked_PG(reward_delta.detach(), torch.log(dps[0]).squeeze(),
                     ask_flag.detach())
    loss.backward()
    self.d_optimizer.step()

    # also save the question asked and answer, and top-k predictions from captioner
    answers = [torch.max(x, dim=1)[1] for x in aps]
    topwords = [torch.topk(x, 20)[1] for x in cps]
    question_lens = [x.sum(dim=1) for x in qmasks]
    self.data_collector.collect_batch(index.clone(), caps, cmasks,
                                      [base_rwd, rwd, rwd_greedy], answers, qs,
                                      question_lens, topwords)

    # Logging
    if self.collection_steps % self.opt.print_every == 0:
        c_i2w = self.val_loader.dataset.c_i2w
        p_i2w = self.val_loader.dataset.p_i2w
        q_i2w = self.val_loader.dataset.q_i2w
        a_i2w = self.val_loader.dataset.a_i2w
        caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len, \
            flag, raw_idx, refs, ref_lens = \
            [util.to_np(x) for x in
             [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
              rollout_mask.long().sum(dim=1), masked_prob, qs[0], qlps[0],
              qmasks[0].long().sum(dim=1), ask_flag.long(), ask_idx, refs, ref_lens]]
        pos_probs, ans_probs, cap_probs = pps[0], aps[0], cps[0]
        for i in range(image.size(0)):
            top_pos = torch.topk(pos_probs, 3)[1]
            top_ans = torch.topk(ans_probs[i], 3)[1]
            top_cap = torch.topk(cap_probs[i], 5)[1]
            word = c_i2w[caption[i][raw_idx[i]]] if raw_idx[i] < len(caption[i]) else 'Nan'
            entry = {
                'img': img_path[i],
                'epoch': self.collection_epoch,
                'caption':
                    ' '.join(util.idx2str(c_i2w, (caption[i])[:cap_len[i]]))
                    + " | Reward: {:.2f}".format(base_rwd[i]),
                'qaskprobs':
                    ' '.join(["{:.2f}".format(x) if x > 0.0 else "_"
                              for x in dec_probs[i]]),
                'rollout_caption':
                    ' '.join(util.idx2str(c_i2w, (rollout[i])[:rollout_len[i]]))
                    + " | Reward: {:.2f}".format(rollout_rwd[i]),
                'replace_caption':
                    ' '.join(util.idx2str(c_i2w, (replace[i])[:cap_len[i]]))
                    + " | Reward: {:.2f}".format(replace_rwd[i]),
                'index': raw_idx[i],
                'flag': bool(flag[i]),
                'word': word,
                'pos': ' | '.join([p_i2w[x.item()] if x.item() in p_i2w else p_i2w[18]
                                   for x in top_pos]),
                'question':
                    ' '.join(util.idx2str(q_i2w, (question[i])[:q_len[i]]))
                    + " | logprob: {}".format(q_logprob[i, :q_len[i]].sum()),
                'answers': ' | '.join([a_i2w[x.item()] for x in top_ans]),
                'words': ' | '.join([c_i2w[x.item()] for x in top_cap]),
                'refs': [
                    ' '.join(util.idx2str(c_i2w, (refs[i, j])[:ref_lens[i, j]]))
                    for j in range(3)
                ]
            }
            self.trainLLvisualizer.add_entry(entry)

        info = {
            'collect/question logprob':
                torch.mean(torch.cat(qlps, dim=1)).item(),
            'collect/replace over rollout': float(stat_rero) / batch_size,
            'collect/rollout over replace': float(stat_rore) / batch_size,
            'collect/replace over all': float(stat_reall) / batch_size,
            'collect/rollout over all': float(stat_roall) / batch_size,
            'collect/sampled decision equals greedy decision':
                float((ask_idx == ask_idx_greedy).sum().item()) / batch_size,
            'collect/question asking frequency (percent)':
                float(ask_flag.sum().item()) / ask_flag.size(0)
        }
        util.step_logging(self.logger, info, self.collection_steps)

    return loss.item()
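# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original listing): masked_PG is called
# above as masked_PG(reward_delta, log_prob, ask_flag) with three [batch]
# tensors. Assuming it is a standard REINFORCE-style policy-gradient loss
# restricted to the samples that actually asked a question, a minimal
# implementation could look like this:
import torch

def masked_PG(reward, logprob, mask):
    """-E[reward * log pi(a)] over the masked-in samples."""
    mask = mask.float()
    # guard against a batch in which no sample asked a question
    denom = mask.sum().clamp(min=1.0)
    return -(reward * logprob * mask).sum() / denom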
def evaluate_with_questions(self):
    self.std_logger.info("Validating decision maker")
    self.dmaker.eval()
    self.captioner.train()
    self.qgen.train()
    # if self.opt.cap_eval:
    #     self.captioner.eval()
    # else:
    #     self.captioner.train()
    #
    # if self.opt.quegen_eval:
    #     self.question_generator.eval()
    # else:
    #     self.question_generator.train()

    c_i2w = self.val_loader.dataset.c_i2w
    correct, pos_correct, eval_scores = [0.0, 0.0], [0.0, 0.0], []
    caption_predictions = [[], []]
    with torch.no_grad():
        for step, sample in enumerate(self.val_loader):
            image, refs, target, caption_len = [
                x.to(device)
                for x in [sample[0], sample[4], sample[2], sample[3]]
            ]
            ref_lens, img_path, index, img_id = \
                sample[5], sample[7], sample[8], sample[9]
            batch_size = image.size(0)
            caps, cmasks, poses, pps, qs, qlps, qmasks, aps, cps, atts = \
                [], [], [], [], [], [], [], [], [], []

            # 1. Caption completely
            self.set_seed(self.opt.seed)
            r = self.captioner.sample(
                image, greedy=True,
                max_seq_len=self.opt.c_max_sentence_len + 1)
            caption, cap_probs, cap_mask, pos_probs, att, topk_words, attended_img \
                = r.caption, r.prob, r.mask, r.pos_prob, r.attention.squeeze(), r.topk, r.atdimg
            cap_len = cap_mask.long().sum(dim=1)
            caps.append(caption)
            cmasks.append(cap_mask)
            poses.append(pos_probs)
            caption = self.pad_caption(caption, cap_len)

            # get the hidden state context
            source = torch.cat([
                self.ones_vector[:batch_size].unsqueeze(1), caption[:, :-1]
            ], dim=1)
            r = self.fixed_caption_encoder(image, source, gt_pos=None, ss=False)
            h = r.hidden

            # 2. Identify the best time to ask a question, excluding ended sentences
            logit, valid_pos_mask = self.dmaker(
                h, attended_img, caption, cap_len, pos_probs, topk_words,
                self.captioner.caption_embedding.weight.data)
            masked_prob = masked_softmax(
                logit, cap_mask, valid_pos_mask,
                max_len=self.opt.c_max_sentence_len)
            dm_prob, ask_idx, ask_flag, ask_mask = self.sample_decision(
                masked_prob, cap_mask, greedy=True)

            # 3. Ask the teacher a question and get the answer
            ans, ans_mask, r = self.ask_question(image, caption, refs, pos_probs,
                                                 h, att, ask_idx, q_greedy=True)

            # logging (entries are duplicated so indices line up with
            # do_iteration, which logs both a sampled and a greedy decision)
            cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
            cps.append(cap_probs[self.range_vector[:batch_size], ask_idx])
            pps.append(r.pos_prob[0])
            pps.append(r.pos_prob[0])
            atts.append(r.att[0])
            atts.append(r.att[0])
            qlps.append(r.q_logprob.unsqueeze(1))
            qlps.append(r.q_logprob.unsqueeze(1))
            qmasks.append(r.q_mask)
            qmasks.append(r.q_mask)
            qs.append(r.question)
            qs.append(r.question)
            aps.append(r.ans_prob)
            aps.append(r.ans_prob)

            # 4. Compute new captions based on teacher's answer
            # rollout caption
            r = self.caption_with_teacher_answer(image, ask_mask, ans_mask,
                                                 greedy=True)
            poses.append(r.pos_prob)
            rollout = r.caption
            rollout_mask = r.cap_mask
            # replace caption
            replace = replace_word_in_caption(caps[0], ans, ask_idx, ask_flag)

            base_rwd = mixed_reward(caps[0], torch.sum(cmasks[0], dim=1), refs,
                                    ref_lens, self.scorers, self.c_i2w)
            rollout_rwd = mixed_reward(rollout, torch.sum(rollout_mask, dim=1),
                                       refs, ref_lens, self.scorers, self.c_i2w)
            replace_rwd = mixed_reward(replace, torch.sum(cmasks[0], dim=1), refs,
                                       ref_lens, self.scorers, self.c_i2w)

            caps.append(rollout)
            caps.append(replace)
            cmasks.append(rollout_mask)
            cmasks.append(cmasks[0])

            stat_rero, stat_rore, stat_reall, stat_roall = get_rollout_replace_stats(
                replace_rwd, rollout_rwd, base_rwd)
            best_cap, best_cap_mask = choose_better_caption(
                replace_rwd, replace, cmasks[0], rollout_rwd, rollout, rollout_mask)
            caps = [caps[0], best_cap]
            cmasks = [cmasks[0], best_cap_mask]

            # Collect captions for coco evaluation
            img_id = util.to_np(img_id)
            for i in range(len(caps)):
                words, lens = util.to_np(caps[i]), util.to_np(cmasks[i].sum(dim=1))
                for j in range(image.size(0)):
                    inds = words[j][:lens[j]]
                    caption = ' '.join(c_i2w[ind] for ind in inds)
                    pred = {'image_id': img_id[j], 'caption': caption}
                    caption_predictions[i].append(pred)

            for i in range(len(correct)):
                predictions = caps[i] * cmasks[i].long()
                correct[i] += ((target == predictions).float()
                               / caption_len.float().unsqueeze(1)).sum().item()

            # Logging
            if step % self.opt.val_print_every == 0:
                c_i2w = self.val_loader.dataset.c_i2w
                p_i2w = self.val_loader.dataset.p_i2w
                q_i2w = self.val_loader.dataset.q_i2w
                a_i2w = self.val_loader.dataset.a_i2w
                caption, rollout, replace, cap_len, rollout_len, dec_probs, question, q_logprob, q_len, \
                    flag, raw_idx, refs, ref_lens = \
                    [util.to_np(x) for x in
                     [caps[0], rollout, replace, cmasks[0].long().sum(dim=1),
                      rollout_mask.long().sum(dim=1), masked_prob, qs[0], qlps[0],
                      qmasks[0].long().sum(dim=1), ask_flag.long(), ask_idx,
                      refs, ref_lens]]
                pos_probs, ans_probs, cap_probs = pps[0], aps[0], cps[0]
                for i in range(image.size(0)):
                    top_pos = torch.topk(pos_probs, 3)[1]
                    top_ans = torch.topk(ans_probs[i], 3)[1]
                    top_cap = torch.topk(cap_probs[i], 5)[1]
                    word = c_i2w[caption[i][raw_idx[i]]] \
                        if raw_idx[i] < len(caption[i]) else 'Nan'
                    entry = {
                        'img': img_path[i],
                        'epoch': self.collection_epoch,
                        'caption':
                            ' '.join(util.idx2str(c_i2w, (caption[i])[:cap_len[i]]))
                            + " | Reward: {:.2f}".format(base_rwd[i]),
                        'qaskprobs':
                            ' '.join(["{:.2f}".format(x) if x > 0.0 else "_"
                                      for x in dec_probs[i]]),
                        'rollout_caption':
                            ' '.join(util.idx2str(c_i2w, (rollout[i])[:rollout_len[i]]))
                            + " | Reward: {:.2f}".format(rollout_rwd[i]),
                        'replace_caption':
                            ' '.join(util.idx2str(c_i2w, (replace[i])[:cap_len[i]]))
                            + " | Reward: {:.2f}".format(replace_rwd[i]),
                        'index': raw_idx[i],
                        'flag': bool(flag[i]),
                        'word': word,
                        'pos': ' | '.join([
                            p_i2w[x.item()] if x.item() in p_i2w else p_i2w[18]
                            for x in top_pos
                        ]),
                        'question':
                            ' '.join(util.idx2str(q_i2w, (question[i])[:q_len[i]]))
                            + " | logprob: {}".format(q_logprob[i, :q_len[i]].sum()),
                        'answers': ' | '.join([a_i2w[x.item()] for x in top_ans]),
                        'words': ' | '.join([c_i2w[x.item()] for x in top_cap]),
                        'refs': [
                            ' '.join(util.idx2str(c_i2w, (refs[i, j])[:ref_lens[i, j]]))
                            for j in range(3)
                        ]
                    }
                    self.valLLvisualizer.add_entry(entry)

                info = {
                    'eval/replace over rollout': float(stat_rero) / batch_size,
                    'eval/rollout over replace': float(stat_rore) / batch_size,
                    'eval/replace over all': float(stat_reall) / batch_size,
                    'eval/rollout over all': float(stat_roall) / batch_size,
                    'eval/question asking frequency (percent)':
                        float(ask_flag.sum().item()) / ask_flag.size(0)
                }
                util.step_logging(self.logger, info, self.eval_steps)
            self.eval_steps += 1

    self.valLLvisualizer.update_html()
    acc = [x / len(self.val_loader.dataset) for x in correct]
    for i in range(len(correct)):
        eval_scores.append(
            language_scores(caption_predictions[i], self.opt.run_name,
                            self.result_path,
                            annFile=self.opt.coco_annotation_file))
    self.dmaker.train()
    return acc, eval_scores
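# ---------------------------------------------------------------------------
# Illustrative sketch (inferred from the call sites above, not the original
# implementation): choose_better_caption keeps, per sample, whichever of the
# two candidate captions earned the higher reward. Rewards are numpy arrays
# of shape [batch]; captions and masks are torch tensors of shape
# [batch, seq_len], assumed padded to a common length.
import numpy as np
import torch

def choose_better_caption(rwd_a, cap_a, mask_a, rwd_b, cap_b, mask_b):
    pick_a = torch.from_numpy(rwd_a >= rwd_b).to(cap_a.device).unsqueeze(1)
    best_cap = torch.where(pick_a, cap_a, cap_b)    # broadcast over time
    best_mask = torch.where(pick_a, mask_a, mask_b)
    return best_cap, best_mask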
def validate(self):
    print("Validating")
    self.model.eval()
    loss, correct = 0.0, 0.0
    samples = []
    for step, sample in enumerate(self.val_loader):
        image, question, question_len, answer, captions, cap_lens, img_path, que_id = sample
        image, question, answer, captions = [
            x.to(device) for x in [image, question, answer, captions]
        ]

        # Forward pass
        result = self.model(image, question, captions)
        probs, logits = result.probs, result.logits
        loss += self.loss_function(logits, answer).item()

        # Compute top-answer accuracy
        _, prediction_max_index = torch.max(probs, 1)
        _, answer_max_index = torch.max(answer, 1)
        correct += (answer_max_index == prediction_max_index).float().sum().item()

        a_i2w = self.val_loader.dataset.a_i2w
        q_i2w = self.val_loader.dataset.q_i2w
        c_i2w = self.val_loader.dataset.c_i2w

        # Append predictions in the VQA 2.0 validation format
        for i in range(image.size(0)):
            samples.append({
                'question_id': que_id[i].item(),
                'answer': a_i2w[prediction_max_index[i].item()]
            })

        # write image and Q&A to html file to visualize training progress
        if step % self.opt.visualize_every == 0:
            # Show the top 3 predicted answers
            pros, ans = torch.topk(probs, k=3, dim=1)
            captions, cap_lens, question, question_len, pros, ans = \
                [util.to_np(x) for x in
                 [captions, cap_lens, question, question_len, pros, ans]]
            # integer division (was `/`, a Python 2 leftover):
            # visualize only the first half of the batch
            for i in range(image.size(0) // 2):
                # Show question
                que_arr = util.idx2str(q_i2w, (question[i])[:question_len[i]])
                entry = {
                    'img': img_path[i],
                    'epoch': self.epoch,
                    'question': ' '.join(que_arr),
                    'gt_ans': a_i2w[answer_max_index[i].item()],
                    'predictions': [[p, a_i2w[a]] for p, a in zip(pros[i], ans[i])],
                    'refs': [
                        ' '.join(util.idx2str(c_i2w, (captions[i, j])[:cap_lens[i, j]]))
                        for j in range(3)
                    ]
                }
                self.visualizer.add_entry(entry)
    self.visualizer.update_html()

    # Reset
    self.model.train()
    return samples, [x / len(self.val_loader.dataset) for x in [loss, correct]]
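# ---------------------------------------------------------------------------
# Usage note (illustrative, not from the original source): the `samples`
# list returned above already follows the VQA v2.0 results format,
# [{"question_id": int, "answer": str}, ...], so it can be dumped to JSON
# and scored with the official VQA evaluation tools. The file name below is
# an assumption for the example.
import json

def save_vqa_results(samples, path="vqa_val_results.json"):
    with open(path, "w") as f:
        json.dump(samples, f)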
def save_collected_data(self, output_dir, logger):
    img_avg_rwd = self.get_average_reward_each_image()
    imgs_to_gather = int(len(self.imgId2idx) * self.H)
    gt_reward_idx = min(
        int(len(img_avg_rwd) - (1 - self.lamda) * imgs_to_gather),
        len(img_avg_rwd) - 1)
    gt_reward = img_avg_rwd[gt_reward_idx][1]

    new_data = []
    imgs_collected, captions_collected, skipped_pos_parse_error = 0, 0, 0
    q_gathers, no_q_gathers, total_cider = 0, 0, 0.0
    gaveup_list = []
    img_avg_rwd.reverse()

    # collect the best caption for the top H% images
    for imgId, _ in img_avg_rwd:
        idxs = self.imgId2idx[imgId]
        if imgs_collected < imgs_to_gather:
            for idx in idxs:
                item = self.get_sample_best_caption(idx)
                caption = item['caption']
                cap_str = idx2str(self.c_i2w, caption)
                pos_str = [
                    p for w, p in self.stanfordnlp.pos_tag(
                        ' '.join([clean_str(x) for x in cap_str]))
                ]
                cap_clean, pos_clean = self.clean_cap_pos(cap_str, pos_str)
                item['caption'] = cap_clean
                caption = item['caption']
                if len(pos_clean) == len(caption):
                    if idx in self.question_set:
                        q_gathers += 1
                        item['caption_type'] = 1
                    elif idx in self.no_question_set:
                        no_q_gathers += 1
                        item['caption_type'] = 2
                    item['pos'] = [
                        self.p_w2i[x] if x in self.p_w2i else self.p_w2i['unknown']
                        for x in pos_clean
                    ]
                    new_data.append(item)
                    captions_collected += 1
                    total_cider += item['weight']
                else:
                    gaveup_list.append(idx)
                    skipped_pos_parse_error += 1
        else:
            for idx in idxs:
                gaveup_list.append(idx)
        imgs_collected += 1

    # get the ground truth captions for the rest of the images
    for idx in gaveup_list:
        item = self.collected_data[idx]['gt_data']
        item['weight'] = gt_reward
        item['caption_type'] = 0
        new_data.append(item)
    assert len(new_data) == len(self.data)

    # Logging
    logger.info("Collecting the top {}/{} images.".format(
        imgs_to_gather, len(self.imgId2idx)))
    logger.info(
        "Collected {}/{} captions, and asked {}/{} samples for ground truth."
        .format(captions_collected, len(self.data), len(gaveup_list),
                len(self.data)))
    logger.info(
        "Of the {} captions collected, {} were from asking a question and {} were not."
        .format(captions_collected, q_gathers, no_q_gathers))
    logger.info("{} captions were skipped due to POS parsing errors.".format(
        skipped_pos_parse_error))
    logger.info(
        "The average gathered reward was {} and the 80th-percentile reward "
        "(the GT reward is set to this) is {}.".format(
            total_cider / captions_collected, gt_reward))

    # Save data
    pickle.dump(
        {
            "data": new_data,
            "c_dicts": [self.c_i2w, self.c_w2i],
            "pos_dicts": [self.p_i2w, self.p_w2i],
            "a_dicts": [self.a_i2w, self.a_w2i],
            "q_dicts": [self.q_i2w, self.q_w2i],
            "special_symbols": self.special_symbols,
            "gt_caps_reward": gt_reward
        },
        open(output_dir, 'wb'),
        protocol=pickle.HIGHEST_PROTOCOL)

    stats = {
        k: v
        for k, v in zip([
            'collects', 'num gt', 'q collects', 'no q collects',
            'avg collect reward', 'gt reward', 'pos parse error'
        ], [
            captions_collected, len(gaveup_list), q_gathers, no_q_gathers,
            total_cider / captions_collected, gt_reward,
            skipped_pos_parse_error
        ])
    }
    self.set_gather_stats(stats)
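# ---------------------------------------------------------------------------
# Illustrative sketch (not the original implementation): the code above
# relies on get_average_reward_each_image() returning one (imgId, average
# reward) pair per image, sorted ascending by reward so that .reverse()
# puts the best-scoring images first. Reusing get_sample_best_caption and
# its 'weight' field (which save_collected_data treats as the reward), a
# minimal version could be:
def get_average_reward_each_image(self):
    img_avg_rwd = []
    for img_id, idxs in self.imgId2idx.items():
        rewards = [self.get_sample_best_caption(idx)['weight'] for idx in idxs]
        img_avg_rwd.append((img_id, sum(rewards) / len(rewards)))
    img_avg_rwd.sort(key=lambda pair: pair[1])  # ascending by average reward
    return img_avg_rwd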
# get rollout caption
set_seed(SEED)
r = captioner.sample_with_teacher_answer(
    image, ask_mask.unsqueeze(0), answer_mask,
    torch.zeros([1, 1, 512], dtype=torch.float, device=device),
    torch.ones([1], dtype=torch.long, device=device), 17, True)
rollout, rollout_mask = r.caption, r.cap_mask
rollout_len = rollout_mask.long().sum(dim=1)

# get replace caption
replace = caption.clone()
replace[0, ask_idx] = answer

caption = caption[0, :cap_len].cpu().numpy()
rollout = rollout[0, :rollout_len].cpu().numpy()
replace = replace[0, :cap_len].cpu().numpy()
question = question[0, :q_len].cpu().numpy()
ask_idx = ask_idx.item()

caption = ' '.join(idx2str(ci2w, caption))
rollout = ' '.join(idx2str(ci2w, rollout))
replace = ' '.join(idx2str(ci2w, replace))
question = ' '.join(idx2str(qi2w, question))

print(caption)
print(rollout)
print(replace)
print(question)
print(ask_idx)
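# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not the original helper): a typical
# set_seed, as called at the top of this snippet, seeds every RNG involved
# so the sampled rollout is reproducible.
import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)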
def validate_captioner(self):
    self.captioner.eval()
    word_loss, word_cor, pos_cor, pos_loss = 0.0, 0.0, 0.0, 0.0
    with torch.no_grad():
        for step, sample in enumerate(self.val_loader):
            image, source, target, caption_len, refs, ref_lens, pos, img_path = sample[:-2]
            sample = [
                x.to(device)
                for x in [image, source, target, caption_len, pos]
            ]
            r = validate_helper(sample, self.captioner,
                                self.val_loader.dataset.max_c_len)
            word_loss, word_cor, pos_cor, pos_loss = [
                x + y for x, y in zip([word_loss, word_cor, pos_cor, pos_loss], r)
            ]

            # write image and predicted captions to html file to visualize
            # training progress
            if step % self.opt.val_print_every == 0:
                beam_size = 3
                c_i2w = self.val_loader.dataset.c_i2w
                p_i2w = self.val_loader.dataset.p_i2w
                refs, ref_lens = [util.to_np(x) for x in [refs, ref_lens]]

                # show top 3 beam search captions
                result = self.captioner.sample_beam(sample[0],
                                                    beam_size=beam_size,
                                                    max_seq_len=17)
                captions, lps, lens = [
                    util.to_np(x) for x in [
                        result.captions,
                        torch.sum(result.logprobs, dim=2),
                        torch.sum(result.logprobs.abs() > 0.00001, dim=2)
                    ]
                ]

                # show greedy caption
                result = self.captioner.sample(sample[0], greedy=True,
                                               max_seq_len=17)
                pprob, pidx = torch.max(result.pos_prob, dim=2)
                gcap, glp, glens, pidx, pprob = [
                    util.to_np(x) for x in [
                        result.caption,
                        torch.sum(result.log_prob, dim=1),
                        torch.sum(result.mask, dim=1), pidx, pprob
                    ]
                ]
                for i in range(image.size(0)):
                    cap_arr = util.idx2str(c_i2w, (gcap[i])[:glens[i]])
                    pos_arr = util.idx2pos(p_i2w, (pidx[i])[:glens[i]])
                    pos_pred = [
                        "{} ({} {:.2f})".format(cap_arr[j], pos_arr[j], pprob[i, j])
                        for j in range(glens[i])
                    ]
                    entry = {
                        'img': img_path[i],
                        'epoch': self.cap_epoch,
                        'greedy_sample':
                            ' '.join(cap_arr) + " logprob: {}".format(glp[i]),
                        'pos_pred': ' '.join(pos_pred),
                        'beamsearch': [
                            ' '.join(util.idx2str(c_i2w, (captions[i, j])[:lens[i, j]]))
                            + " logprob: {}".format(lps[i, j])
                            for j in range(beam_size)
                        ],
                        'refs': [
                            ' '.join(util.idx2str(c_i2w, (refs[i, j])[:ref_lens[i, j]]))
                            for j in range(3)
                        ]
                    }
                    self.Cvisualizer.add_entry(entry)
        self.Cvisualizer.update_html()

    # Reset
    self.captioner.train()
    return [
        x / len(self.val_loader.dataset)
        for x in [word_loss, word_cor, pos_cor, pos_loss]
    ]
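# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not the original util module): idx2str,
# used throughout these validation loops, maps an index sequence to word
# strings through an index-to-word dict. The '<unk>' fallback is a guess;
# idx2pos would be the analogous lookup into the POS dictionary.
def idx2str(i2w, indices):
    return [i2w[int(i)] if int(i) in i2w else '<unk>' for i in indices]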
def validate(self):
    loss, correct, correct_answers = 0.0, 0.0, 0
    self.model.eval()
    for step, batch in enumerate(self.val_loader):
        img_path = batch[-1]
        batch = [x.to(device) for x in batch[:-1]]
        image, question_len, source, target, caption, q_idx_vec, pos, att, \
            context, refs, answer = self.compute_cap_features_val(batch)

        logits = self.model(image, caption, pos, context, att, source, q_idx_vec)
        batch_loss = masked_CE(logits, target, question_len)
        loss += batch_loss.item()

        predictions = seq_max_and_mask(logits, question_len,
                                       self.val_loader.dataset.max_q_len + 1)
        correct += torch.sum((target == predictions).float()
                             / question_len.float().unsqueeze(1)).item()

        # evaluate using VQA expert
        result = self.model.sample(image, caption, pos, context, att, q_idx_vec,
                                   greedy=True,
                                   max_seq_len=self.val_loader.dataset.max_q_len + 1)
        sample, log_probs, mask = result.question, result.log_prob, result.mask
        correct_answers += self.query_vqa(image, sample, refs, answer, mask)

        # write image and Q&A to html file to visualize training progress
        if step % self.opt.visualize_every == 0:
            beam_size = 3
            c_i2w = self.val_loader.dataset.c_i2w
            a_i2w = self.val_loader.dataset.a_i2w
            q_i2w = self.val_loader.dataset.q_i2w
            sample_len = mask.sum(dim=1)
            _, _, beam_predictions, beam_lps = self.model.sample_beam(
                image, caption, pos, context, att, q_idx_vec,
                beam_size=beam_size, max_seq_len=15)
            beam_predictions, lps, lens = [
                util.to_np(x) for x in [
                    beam_predictions,
                    torch.sum(beam_lps, dim=2),
                    torch.sum(beam_lps != 0, dim=2)
                ]
            ]
            target, question_len, sample, sample_len = [
                util.to_np(x)
                for x in [target, question_len, sample, sample_len]
            ]
            for i in range(image.size(0)):
                entry = {
                    'img': img_path[i],
                    'epoch': self.epoch,
                    'answer': a_i2w[answer[i].item()],
                    'gt_question': ' '.join(
                        util.idx2str(q_i2w, (target[i])[:question_len[i] - 1])),
                    'greedy_question': ' '.join(
                        util.idx2str(q_i2w, (sample[i])[:sample_len[i]])),
                    'beamsearch': [
                        ' '.join(util.idx2str(
                            q_i2w, (beam_predictions[i, j])[:lens[i, j]]))
                        + " logprob: {}".format(lps[i, j])
                        for j in range(beam_size)
                    ]
                }
                self.visualizer.add_entry(entry)
            self.visualizer.update_html()

    # Reset
    self.model.train()
    l = len(self.val_loader.dataset)
    return [loss / l, correct / l, 100.0 * correct_answers / l]
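# ---------------------------------------------------------------------------
# Illustrative sketches (signatures inferred from the call sites above, not
# the original implementations). masked_CE: token-level cross entropy that
# ignores positions past each sequence's length. seq_max_and_mask: greedy
# argmax predictions with padded positions zeroed out.
import torch
import torch.nn.functional as F

def _length_mask(lengths, max_len):
    # [batch, max_len] boolean mask, True at valid positions
    steps = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return steps < lengths.unsqueeze(1)

def masked_CE(logits, target, lengths):
    # logits: [batch, max_len, vocab], target: [batch, max_len]
    ce = F.cross_entropy(logits.transpose(1, 2), target, reduction='none')
    mask = _length_mask(lengths, logits.size(1)).float()
    return (ce * mask).sum() / mask.sum()

def seq_max_and_mask(logits, lengths, max_len):
    predictions = logits.argmax(dim=2)
    return predictions * _length_mask(lengths, max_len).long()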