def lm_trainstep(train_step, global_step, reader, model, sampler, sess):
    """Run one training step of the language-model discriminator.

    Draws "fake" questions from `sampler` conditioned on (image, answer),
    pairs them with ground-truth questions as real examples, and applies
    one optimizer update (`train_step`) to the language model.

    :param train_step: TF op applying one optimizer update.
    :param global_step: TF tensor holding the global step counter.
    :param reader: batch reader exposing ``pop_batch()``.
    :param model: language model exposing ``fill_feed_dict``/``inference``/``name``.
    :param sampler: generator exposing ``random_sampling`` and ``pad_token``.
    :param sess: TF session the ops run in.
    :return: tuple ``(loss, global_step_value)``.
    """
    outputs = reader.pop_batch()
    images, quest, quest_len, ans, ans_len = outputs
    # Swap in a random ground-truth question batch for the "real" side.
    quest, quest_len = _Q_CTX.get_gt_batch(quest, quest_len)
    # Sample "fake" questions from the generator.
    noise_vec, pathes, scores = sampler.random_sampling([images, ans, ans_len], sess)
    _this_batch_size = images.shape[0]
    scores, pathes, noise = post_process_variation_questions_noise(
        scores, pathes, noise_vec, _this_batch_size)
    # CNN-style discriminators require fixed-length (max_length) padding.
    if 'CNN' in model.name:
        lm_inputs = wrap_samples_for_language_model(
            sampled=pathes, pad_token=sampler.pad_token - 1,
            gts=[quest, quest_len], max_length=20)
    else:
        lm_inputs = wrap_samples_for_language_model(
            sampled=pathes, pad_token=sampler.pad_token - 1,
            gts=[quest, quest_len])
    loss, gstep = sess.run([train_step, global_step],
                           feed_dict=model.fill_feed_dict(lm_inputs))
    if gstep % 500 == 0:
        # Periodically print a few fake/real examples with their LM scores.
        def _show_examples(arr, arr_len, name):
            _rewards = model.inference([arr, arr_len])
            ps = _parse_gt_questions(arr, arr_len)
            print('\n%s:' % (name))
            for p, r in zip(ps, _rewards):
                # FIX: guard empty paths before checking the trailing END
                # token (id 2); the original `p[-1]` raised IndexError on [].
                if p and p[-1] == 2:
                    p = p[:-1]
                sent = _SENT.index_to_question(p)
                print('%s (%0.3f)' % (sent, r))
        fake, fake_len, real, real_len = lm_inputs
        _show_examples(fake, fake_len, 'Fake')
        _show_examples(real, real_len, 'Real')
    return loss, gstep
def inference(self, ivqa_pathes):
    """Score sampled question paths for linguistic legality.

    Runs the neural language model over the wrapped paths, then forces
    the score of any path that matches a ground-truth exemplar to 1.0.

    :param ivqa_pathes: list of sampled token-id question paths.
    :return: array of legality scores, one per path.
    """
    lm_feed = wrap_samples_for_language_model(
        [ivqa_pathes], pad_token=self.pad_token - 1, max_length=20)
    scores = self.nn_lm.inference(lm_feed)
    # Exemplar hits are ground-truth questions: maximally legal by definition.
    gt_hits = self.eg_lm.query(ivqa_pathes, self.min_gt_count)
    scores[gt_hits] = 1.0
    return scores
def ivqa_decoding_beam_search(checkpoint_path=None):
    """Generate questions on the 'kptest' split, filter them with a language
    model, verify them with a VQA model, and save results + VQA scores.

    :param checkpoint_path: explicit generator checkpoint; when None it is
        resolved from FLAGS (checkpoint_dir or checkpoint_pat pattern).
    :return: tuple (result json path, mean VQA score over kept batches).
    """
    model_config = ModelConfig()
    method = FLAGS.method
    res_file = 'result/bs_gen_%s.json' % method
    score_file = 'result/bs_vqa_scores_%s.mat' % method
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-VVIS', phase='test')
    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')
    # get data reader
    subset = 'kptest'
    reader = create_fn(batch_size=1, subset=subset, version=FLAGS.test_version)
    exemplar = ExemplarLanguageModel()
    if checkpoint_path is None:
        if FLAGS.checkpoint_dir:
            ckpt_dir = FLAGS.checkpoint_dir
        else:
            ckpt_dir = FLAGS.checkpoint_pat % (FLAGS.version, FLAGS.model_type)
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path
    # Build model graph; all ops (generator, LM, VQA) live in graph `g`.
    g = tf.Graph()
    with g.as_default():
        # Build the generator in sampling mode (1000 samples per input).
        model = model_fn(model_config, 'sampling')
        model.set_num_sampling_points(1000)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)
        # build language model (legality filter)
        language_model = LanguageModel()
        language_model.build()
        language_model.set_cache_dir('test_empty')
        language_model.set_session(sess)
        language_model.setup_model()
        # build VQA model (answerability verifier)
        vqa_model = VQAWrapper(g, sess)
    num_batches = reader.num_batches
    print('Running beam search inference...')
    results = []
    batch_vqa_scores = []
    num = FLAGS.max_iters if FLAGS.max_iters > 0 else num_batches
    for i in range(num):
        outputs = reader.get_test_batch()
        # inference
        quest_ids, image_ids = outputs[-2:]
        im, _, _, top_ans, ans_tokens, ans_len = outputs[:-2]
        # 2000 is the out-of-vocabulary answer id; skip those samples.
        if top_ans == 2000:
            continue
        print('\n%d/%d' % (i, num))
        question_id = int(quest_ids[0])
        image_id = int(image_ids[0])
        t1 = time()
        pathes, scores = model.greedy_inference([im, ans_tokens, ans_len], sess)
        # find unique paths among the samples
        ivqa_scores, ivqa_pathes = process_one(scores, pathes)
        t2 = time()
        print('Time for sample generation: %0.2fs' % (t2 - t1))
        # apply language model: score legality, force exemplar matches to 1.
        language_model_inputs = wrap_samples_for_language_model(
            [ivqa_pathes], pad_token=model.pad_token - 1, max_length=20)
        match_gt = exemplar.query(ivqa_pathes)
        legality_scores = language_model.inference(language_model_inputs)
        legality_scores[match_gt] = 1.0
        num_keep = max(100, (legality_scores > 0.1).sum())  # no less than 100
        valid_inds = (-legality_scores).argsort()[:num_keep]
        t3 = time()
        print('Time for language model filtration: %0.2fs' % (t3 - t2))
        # apply VQA model to the LM-surviving candidates
        sampled = [ivqa_pathes[_idx] for _idx in valid_inds]
        vqa_scores, is_valid = vqa_model.get_scores(sampled, im, top_ans)
        conf_inds = np.where(is_valid)[0]
        t4 = time()
        print('Time for VQA verification: %0.2fs' % (t4 - t3))
        # NOTE(review): conf_inds may be empty, making this mean() NaN and
        # polluting the final average — confirm is_valid is never all-False.
        this_mean_vqa_score = vqa_scores[conf_inds].mean()
        print('sampled: %d, unique: %d, legal: %d, gt: %d, mean score %0.2f' % (
            pathes.shape[0], len(ivqa_pathes), num_keep, match_gt.sum(),
            this_mean_vqa_score))
        batch_vqa_scores.append(this_mean_vqa_score)
        # Record one result entry per confident candidate; aug ids pack the
        # candidate index into the low 3 digits of question_id*1000.
        for _pid, idx in enumerate(conf_inds):
            path = sampled[idx]
            sc = vqa_scores[idx]
            sentence = to_sentence.index_to_question(path)
            aug_quest_id = question_id * 1000 + _pid
            res_i = {
                'image_id': int(image_id),
                'question_id': aug_quest_id,
                'question': sentence,
                'score': float(sc)
            }
            results.append(res_i)
    save_json(res_file, results)
    batch_vqa_scores = np.array(batch_vqa_scores, dtype=np.float32)
    mean_vqa_score = batch_vqa_scores.mean()
    from scipy.io import savemat
    savemat(score_file, {
        'scores': batch_vqa_scores,
        'mean_score': mean_vqa_score
    })
    print('BS mean VQA score: %0.3f' % mean_vqa_score)
    return res_file, mean_vqa_score
def ivqa_decoding_beam_search(checkpoint_path=None):
    """Generate question candidates for a hand-picked set of question ids
    on the 'kpval' split (visualization use) and save them to JSON.

    Only LM filtering is applied (no VQA verification, unlike the other
    decoding variants in this file).

    :param checkpoint_path: explicit generator checkpoint; resolved from
        FLAGS when None.
    :return: None (results are written to res_file).
    """
    model_config = ModelConfig()
    method = FLAGS.method
    res_file = 'result/bs_cand_for_vis.json'
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-VVIS', phase='test')
    # Create the vocabulary.
    to_sentence = SentenceGenerator(
        trainset='trainval',
        top_ans_file='../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt')
    # get data reader
    subset = 'kpval'
    reader = create_fn(batch_size=1, subset=subset, version=FLAGS.test_version)
    exemplar = ExemplarLanguageModel()
    if checkpoint_path is None:
        if FLAGS.checkpoint_dir:
            ckpt_dir = FLAGS.checkpoint_dir
        else:
            ckpt_dir = FLAGS.checkpoint_pat % (FLAGS.version, FLAGS.model_type)
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path
    # Build model
    g = tf.Graph()
    with g.as_default():
        # Generator in sampling mode, 5000 samples per input.
        model = model_fn(model_config, 'sampling')
        model.set_num_sampling_points(5000)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)
        # build language model
        language_model = LanguageModel()
        language_model.build()
        language_model.set_cache_dir('test_empty')
        language_model.set_session(sess)
        language_model.setup_model()
    num_batches = reader.num_batches
    # Only these question ids are visualized; all others are skipped.
    quest_ids_to_vis = {5682052: 'bread', 965492: 'plane', 681282: 'station'}
    print('Running beam search inference...')
    results = []
    batch_vqa_scores = []
    num = FLAGS.max_iters if FLAGS.max_iters > 0 else num_batches
    for i in range(num):
        outputs = reader.get_test_batch()
        # inference
        quest_ids, image_ids = outputs[-2:]
        quest_id_key = int(quest_ids)
        if quest_id_key not in quest_ids_to_vis:
            continue
        im, gt_q, _, top_ans, ans_tokens, ans_len = outputs[:-2]
        # 2000 is the out-of-vocabulary answer id; skip those samples.
        if top_ans == 2000:
            continue
        print('\n%d/%d' % (i, num))
        question_id = int(quest_ids[0])
        image_id = int(image_ids[0])
        t1 = time()
        pathes, scores = model.greedy_inference([im, ans_tokens, ans_len], sess)
        # find unique paths among the samples
        ivqa_scores, ivqa_pathes = process_one(scores, pathes)
        t2 = time()
        print('Time for sample generation: %0.2fs' % (t2 - t1))
        # apply language model: legality filter, exemplar matches forced to 1.
        language_model_inputs = wrap_samples_for_language_model(
            [ivqa_pathes], pad_token=model.pad_token - 1, max_length=20)
        match_gt = exemplar.query(ivqa_pathes)
        legality_scores = language_model.inference(language_model_inputs)
        legality_scores[match_gt] = 1.0
        num_keep = max(100, (legality_scores > 0.1).sum())  # no less than 100
        valid_inds = (-legality_scores).argsort()[:num_keep]
        print('keep: %d/%d' % (num_keep, len(ivqa_pathes)))
        t3 = time()
        print('Time for language model filtration: %0.2fs' % (t3 - t2))

        def token_arr_to_list(arr):
            # Flatten a token array into a plain python list of ids.
            return arr.flatten().tolist()

        for _pid, idx in enumerate(valid_inds):
            path = ivqa_pathes[idx]
            sentence = to_sentence.index_to_question(path)
            aug_quest_id = question_id * 1000 + _pid
            res_i = {'image_id': int(image_id),
                     'aug_id': aug_quest_id,
                     'question_id': question_id,
                     'target': sentence,
                     'top_ans_id': int(top_ans),
                     'question': to_sentence.index_to_question(token_arr_to_list(gt_q)),
                     'answer': to_sentence.index_to_answer(token_arr_to_list(ans_tokens))}
            results.append(res_i)
    save_json(res_file, results)
    return None
def reinforce_trainstep(reader_outputs, model, env, sess, task_ops, _VQA_Belief):
    """Run one REINFORCE update of the question generator.

    Samples questions, computes rewards from the environment (VQA score,
    language score, etc.), feeds the best paths back as a policy-gradient
    step, and records (path, score) pairs into `_VQA_Belief`.

    :param reader_outputs: unpacked batch (res5c features, images, question,
        question length, top answer id, answer tokens, answer length,
        question ids, image ids).
    :param model: generator exposing random_sampling / fill_feed_dict / pad_token.
    :param env: reward environment exposing get_reward, language_thresh, lm.
    :param sess: TF session.
    :param task_ops: list of TF ops to run for the policy-gradient update.
    :param _VQA_Belief: store collecting sampled paths with their scores.
    :return: list of session outputs plus [avg_reward, 'reward'] appended.
    """
    res5c, images, quest, quest_len, top_ans, ans, ans_len, quest_ids, image_ids = reader_outputs
    # random sampling of candidate questions
    noise_vec, pathes, scores = model.random_sampling([images, ans, ans_len], sess)
    _this_batch_size = images.shape[0]
    scores, pathes, noise = post_process_variation_questions_noise(
        scores, pathes, noise_vec, _this_batch_size, find_unique=False)
    lm_inputs = wrap_samples_for_language_model(sampled=pathes,
                                                pad_token=model.pad_token - 1,
                                                gts=[quest, quest_len],
                                                max_length=20)

    def _show_examples(arr, arr_len, _rewards, name):
        # Debug helper: print decoded questions with their rewards.
        ps = _parse_gt_questions(arr, arr_len)
        print('\n%s:' % (name))
        for p, r in zip(ps, _rewards):
            if p[-1] == 2:
                p = p[:-1]
            sent = env.to_sentence.index_to_question(p)
            print('%s (%d)' % (sent, r))

    # compute reward; this variant passes res5c features to the VQA scorer.
    vqa_inputs = [images, res5c, ans, ans_len, top_ans]
    wrapped_sampled = lm_inputs[:2]
    rewards, rewards_all, is_gt, aug_data = env.get_reward(
        pathes, [quest, quest_len],
        [vqa_inputs, wrapped_sampled, scores, quest_ids])
    max_path_arr, max_path_len, max_noise, max_rewards = \
        prepare_reinforce_data(pathes, noise, rewards, pad_token=model.pad_token)
    # rewards_all columns: 0 = VQA score, 2 = language score
    # (column meanings assumed from usage here — TODO confirm in env.get_reward).
    vqa_scores = rewards_all[:, 0]
    language_scores = rewards_all[:, 2]
    # Zero out VQA score for paths below the language legality threshold.
    scores = vqa_scores * (language_scores > env.language_thresh)
    new_pathes = _parse_gt_questions(max_path_arr, max_path_len)
    _VQA_Belief.insert(new_pathes, scores)
    aug_images, aug_ans, aug_ans_len, is_in_vocab = aug_data
    sess_in = [
        aug_images, max_path_arr, max_path_len, aug_ans, aug_ans_len,
        max_noise, max_rewards, rewards_all
    ]
    sess_in = [_in[is_in_vocab] for _in in sess_in]  # remove oov
    avg_reward = max_rewards.mean()
    # train op (policy-gradient step)
    sess_outputs = sess.run(task_ops, feed_dict=model.fill_feed_dict(sess_in))
    sess_outputs += [avg_reward, 'reward']
    # update language model — currently DISABLED via `if False`
    # (kept as a deliberate toggle; flip to re-enable discriminator updates).
    if False:  # at least half is generated
        wrapped_gt = _Q_CTX.get_gt_batch(*lm_inputs[2:])  # random sample new
        corrected_inputs = correct_language_model_inputs(
            wrapped_sampled + wrapped_gt, is_gt)
        if min(wrapped_sampled[1].size, wrapped_gt[1].size) > 0:
            env.lm.trainstep(corrected_inputs)
    return sess_outputs
def ivqa_decoding_beam_search(checkpoint_path=None):
    """Mine negative (low-legality) question paths from the 'kptrain' split
    for language-model initialization, and save them to JSON.

    Paths scoring below 0.2 under the language model (and not matching a
    ground-truth exemplar) are serialized and collected, up to ~100k.

    :param checkpoint_path: explicit generator checkpoint; resolved from
        FLAGS when None.
    """
    model_config = ModelConfig()
    method = FLAGS.method
    res_file = 'result/bs_gen_%s.json' % method
    score_file = 'result/bs_vqa_scores_%s.mat' % method
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-VVIS', phase='test')
    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')
    # get data reader
    subset = 'kptrain'
    reader = create_fn(batch_size=1, subset=subset, version=FLAGS.test_version)
    exemplar = ExemplarLanguageModel()
    if checkpoint_path is None:
        if FLAGS.checkpoint_dir:
            ckpt_dir = FLAGS.checkpoint_dir
        else:
            ckpt_dir = FLAGS.checkpoint_pat % (FLAGS.version, FLAGS.model_type)
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path
    # Build model
    g = tf.Graph()
    with g.as_default():
        # Generator in sampling mode; only 5 samples per input here.
        model = model_fn(model_config, 'sampling')
        model.set_num_sampling_points(5)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)
        # build language model
        language_model = LanguageModel()
        language_model.build()
        language_model.set_cache_dir('test_empty')
        language_model.set_session(sess)
        language_model.setup_model()
    num_batches = reader.num_batches
    print('Running beam search inference...')
    num = FLAGS.max_iters if FLAGS.max_iters > 0 else num_batches
    neg_pathes = []
    need_stop = False
    for i in range(num):
        outputs = reader.get_test_batch()
        # inference
        im, _, _, top_ans, ans_tokens, ans_len = outputs[:-2]
        # 2000 is the out-of-vocabulary answer id; skip those samples.
        if top_ans == 2000:
            continue
        print('\n%d/%d' % (i, num))
        t1 = time()
        pathes, scores = model.greedy_inference([im, ans_tokens, ans_len], sess)
        # find unique paths among the samples
        ivqa_scores, ivqa_pathes = process_one(scores, pathes)
        t2 = time()
        print('Time for sample generation: %0.2fs' % (t2 - t1))
        # apply language model; exemplar matches are forced legal so they
        # can never be collected as negatives.
        language_model_inputs = wrap_samples_for_language_model(
            [ivqa_pathes], pad_token=model.pad_token - 1, max_length=20)
        match_gt = exemplar.query(ivqa_pathes)
        legality_scores = language_model.inference(language_model_inputs)
        legality_scores[match_gt] = 1.0
        neg_inds = np.where(legality_scores < 0.2)[0]
        for idx in neg_inds:
            # [1:] drops the leading start token before serialization.
            ser_neg = serialize_path(ivqa_pathes[idx][1:])
            neg_pathes.append(ser_neg)
            if len(neg_pathes) > 100000:
                need_stop = True
                break
        if need_stop:
            break
    sv_file = 'data/lm_init_neg_pathes.json'
    save_json(sv_file, neg_pathes)
def ivqa_decoding_beam_search(ckpt_dir, method):
    """Generate LM-filtered question candidates on the 'bs_test' split and
    save them keyed by question id.

    `inf_type` is hard-coded to 'beam' (sampling_beam mode); 'rand' selects
    plain sampling. Skips the whole run if the output file already exists.

    :param ckpt_dir: directory containing the generator checkpoint.
    :param method: tag used to name the result file.
    """
    model_config = ModelConfig()
    inf_type = 'beam'
    assert (inf_type in ['beam', 'rand'])
    if inf_type == 'rand':
        res_file = 'result/bs_RL2_cands_LM_%s.json' % method
    else:
        res_file = 'result/bs_RL2_cands_LM_%s_BEAM.json' % method
    # Idempotency guard: don't redo finished runs.
    if os.path.exists(res_file):
        print('File %s already exist, skipped' % res_file)
        return
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-VVIS', phase='test')
    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')
    # get data reader
    subset = 'bs_test'
    reader = create_fn(batch_size=1, subset=subset, version=FLAGS.test_version)
    exemplar = ExemplarLanguageModel()
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    checkpoint_path = ckpt.model_checkpoint_path
    # Build model
    g = tf.Graph()
    with g.as_default():
        # Generator mode depends on inf_type; both use 1000 sampling points.
        if inf_type == 'rand':
            model = model_fn(model_config, 'sampling')
            model.set_num_sampling_points(1000)
        else:
            model = model_fn(model_config, 'sampling_beam')
            model.set_num_sampling_points(1000)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)
        # build language model
        language_model = LanguageModel()
        language_model.build()
        language_model.set_cache_dir('test_empty')
        language_model.set_session(sess)
        language_model.setup_model()
    num_batches = reader.num_batches
    print('Running beam search inference...')
    results = {}
    num = FLAGS.max_iters if FLAGS.max_iters > 0 else num_batches
    for i in range(num):
        outputs = reader.get_test_batch()
        # inference
        quest_ids, image_ids = outputs[-2:]
        im, _, _, top_ans, ans_tokens, ans_len = outputs[:-2]
        # 2000 is the out-of-vocabulary answer id; skip those samples.
        if top_ans == 2000:
            continue
        print('\n%d/%d' % (i, num))
        question_id = int(quest_ids[0])
        image_id = int(image_ids[0])
        t1 = time()
        pathes, scores = model.greedy_inference([im, ans_tokens, ans_len], sess)
        # find unique paths among the samples
        ivqa_scores, ivqa_pathes = process_one(scores, pathes)
        t2 = time()
        print('Time for sample generation: %0.2fs' % (t2 - t1))
        # apply language model; note the stricter 0.3 threshold here
        # (the other variants in this file use 0.1).
        language_model_inputs = wrap_samples_for_language_model(
            [ivqa_pathes], pad_token=model.pad_token - 1, max_length=20)
        match_gt = exemplar.query(ivqa_pathes)
        legality_scores = language_model.inference(language_model_inputs)
        legality_scores[match_gt] = 1.0
        num_keep = max(100, (legality_scores > 0.3).sum())  # no less than 100
        valid_inds = (-legality_scores).argsort()[:num_keep]
        t3 = time()
        print('Time for language model filtration: %0.2fs' % (t3 - t2))
        # keep the surviving candidates with their legality scores
        sampled = [ivqa_pathes[_idx] for _idx in valid_inds]
        legality_scores = legality_scores[valid_inds]
        result_key = int(question_id)
        tmp = []
        for idx, path in enumerate(sampled):
            sc = legality_scores[idx]
            sentence = to_sentence.index_to_question(path)
            res_i = {'image_id': int(image_id),
                     'aug_id': idx,
                     'question_id': question_id,
                     'question': sentence,
                     'score': float(sc)}
            tmp.append(res_i)
        print('Number of unique questions: %d' % len(tmp))
        results[result_key] = tmp
    save_json(res_file, results)
def reinforce_trainstep(reader, model, env, sess, task_ops):
    """Run one REINFORCE update of the generator, then update the language
    model discriminator with the freshly sampled fake/real pairs.

    :param reader: batch reader exposing ``pop_batch()``.
    :param model: generator exposing random_sampling / fill_feed_dict / pad_token.
    :param env: reward environment exposing get_reward, to_sentence, lm.
    :param sess: TF session.
    :param task_ops: TF ops for the policy-gradient update.
    :return: list of session outputs plus [avg_reward, 'reward'] appended.
    """
    outputs = reader.pop_batch()
    images, quest, quest_len, ans, ans_len = outputs
    # random sampling of candidate questions
    noise_vec, pathes, scores = model.random_sampling([images, ans, ans_len], sess)
    _this_batch_size = images.shape[0]
    scores, pathes, noise = post_process_variation_questions_noise(
        scores, pathes, noise_vec, _this_batch_size, find_unique=False)
    # Wrap sampled (fake) + ground-truth (real) questions for the LM.
    lm_inputs = wrap_samples_for_language_model(sampled=pathes,
                                                pad_token=model.pad_token - 1,
                                                gts=[quest, quest_len],
                                                max_length=20)

    def _show_examples(arr, arr_len, _rewards, name):
        # Debug helper: print decoded questions with their rewards.
        ps = _parse_gt_questions(arr, arr_len)
        print('\n%s:' % (name))
        for p, r in zip(ps, _rewards):
            if p[-1] == 2:
                p = p[:-1]
            sent = env.to_sentence.index_to_question(p)
            print('%s (%d)' % (sent, r))

    # compute reward
    vqa_inputs = [images, ans, ans_len]
    wrapped_sampled = lm_inputs[:2]
    rewards, rewards_all, is_gt, aug_data = env.get_reward(
        pathes, [quest, quest_len], [vqa_inputs, wrapped_sampled, scores])
    max_path_arr, max_path_len, max_noise, max_rewards = \
        prepare_reinforce_data(pathes, noise, rewards, pad_token=model.pad_token)
    aug_images, aug_ans, aug_ans_len, is_in_vocab = aug_data
    sess_in = [
        aug_images, max_path_arr, max_path_len, aug_ans, aug_ans_len,
        max_noise, max_rewards, rewards_all
    ]
    sess_in = [_in[is_in_vocab] for _in in sess_in]  # remove oov
    avg_reward = max_rewards.mean()
    # train op (policy-gradient step)
    sess_outputs = sess.run(task_ops, feed_dict=model.fill_feed_dict(sess_in))
    sess_outputs += [avg_reward, 'reward']
    # update language model
    # NOTE(review): 80 appears to be the nominal batch size — confirm.
    num_fake_in_batch = 80 - is_gt.sum()
    # NOTE(review): `or True` makes this branch unconditional; kept as a
    # deliberate debug override of the >50 threshold.
    if num_fake_in_batch > 50 or True:  # at least half is generated
        wrapped_gt = _Q_CTX.get_gt_batch(*lm_inputs[2:])  # random sample new
        corrected_inputs = correct_language_model_inputs(
            wrapped_sampled + wrapped_gt, is_gt)
        if num_fake_in_batch > 0:
            env.lm.trainstep(corrected_inputs)
    return sess_outputs
def reinforce_trainstep(reader, model, env, sess, task_ops):
    """Run one REINFORCE update of the generator, then update BOTH the VQA
    model (on ground-truth + augmented samples) and the language model.

    :param reader: batch reader exposing ``pop_batch()``; batch includes
        question ids and the top answer id.
    :param model: generator exposing random_sampling / fill_feed_dict / pad_token.
    :param env: reward environment exposing get_reward, get_vqa_model, lm.
    :param sess: TF session.
    :param task_ops: TF ops for the policy-gradient update.
    :return: list of session outputs plus [avg_reward, 'reward'] appended.
    """
    outputs = reader.pop_batch()
    quest_ids, images, quest, quest_len, top_ans, ans, ans_len = outputs
    # random sampling of candidate questions
    noise_vec, pathes, scores = model.random_sampling([images, ans, ans_len], sess)
    _this_batch_size = images.shape[0]
    scores, pathes, noise = post_process_variation_questions_noise(
        scores, pathes, noise_vec, _this_batch_size, find_unique=False)
    lm_inputs = wrap_samples_for_language_model(sampled=pathes,
                                                pad_token=model.pad_token - 1,
                                                gts=[quest, quest_len],
                                                max_length=20)

    def _show_examples(arr, arr_len, _rewards, name):
        # Debug helper: print decoded questions with their rewards.
        ps = _parse_gt_questions(arr, arr_len)
        print('\n%s:' % (name))
        for p, r in zip(ps, _rewards):
            if p[-1] == 2:
                p = p[:-1]
            sent = env.to_sentence.index_to_question(p)
            print('%s (%d)' % (sent, r))

    # compute reward
    vqa_inputs = [images, ans, ans_len, top_ans]
    wrapped_sampled = lm_inputs[:2]
    rewards, rewards_all, is_gt, aug_data = env.get_reward(
        pathes, [quest, quest_len],
        [vqa_inputs, wrapped_sampled, scores, quest_ids])
    max_path_arr, max_path_len, max_noise, max_rewards = \
        prepare_reinforce_data(pathes, noise, rewards, pad_token=model.pad_token)
    # aug_data here carries the augmented question tensors too (7-tuple),
    # unlike the 4-tuple variants elsewhere in this file.
    aug_images, aug_quest, aug_quest_len, aug_ans, aug_ans_len, aug_top_ans, is_in_vocab = aug_data
    sess_in = [aug_images, max_path_arr, max_path_len, aug_ans, aug_ans_len,
               max_noise, max_rewards, rewards_all]
    sess_in = [_in[is_in_vocab] for _in in sess_in]  # remove oov
    avg_reward = max_rewards.mean()
    # train op (policy-gradient step)
    sess_outputs = sess.run(task_ops, feed_dict=model.fill_feed_dict(sess_in))
    sess_outputs += [avg_reward, 'reward']
    # update VQA model on a concatenation of ground-truth and augmented batches
    aug_legal_mask, aug_vqa_labels, hard_target_mask = correct_vqa_labels(
        aug_top_ans, rewards_all, is_in_vocab)
    # 2000 is the out-of-vocabulary answer id; those GT entries are illegal.
    gt_legal_mask = top_ans != 2000
    gt_hard_target_mask = np.ones_like(top_ans, dtype=np.float32)
    gt_inputs = [images, quest, quest_len, top_ans,
                 gt_hard_target_mask, gt_legal_mask]
    aug_inputs = [aug_images, aug_quest, aug_quest_len, aug_vqa_labels,
                  hard_target_mask, aug_legal_mask]
    vqa_inputs = concat_vqa_batch(gt_inputs, aug_inputs)
    vqa_is_valid = vqa_inputs[-1]
    # remove invalid to save computation
    vqa_inputs = [_in[vqa_is_valid] for _in in vqa_inputs[:-1]]
    vqa = env.get_vqa_model()
    vqa.trainstep(vqa_inputs)
    # update language model
    # NOTE(review): 80 appears to be the nominal batch size — confirm.
    num_fake_in_batch = 80 - is_gt.sum()
    # NOTE(review): `or True` makes this branch unconditional; kept as a
    # deliberate debug override of the >50 threshold.
    if num_fake_in_batch > 50 or True:  # at least half is generated
        wrapped_gt = _Q_CTX.get_gt_batch(*lm_inputs[2:])  # random sample new
        corrected_inputs = correct_language_model_inputs(
            wrapped_sampled + wrapped_gt, is_gt)
        if num_fake_in_batch > 0:
            env.lm.trainstep(corrected_inputs)
    return sess_outputs