def lm_trainstep(train_step, global_step, reader, model, sampler, sess):
    """Run one discriminator-style update of the language model.

    Pops a batch from `reader`, draws sampled questions from `sampler`
    (the fakes) and pairs them with ground-truth questions (the reals),
    then runs `train_step` on the language model `model`.

    Returns:
        (loss, gstep): the LM loss and the global step after the update.
    """
    images, quest, quest_len, ans, ans_len = reader.pop_batch()
    # Replace the batch questions with freshly sampled ground truth.
    quest, quest_len = _Q_CTX.get_gt_batch(quest, quest_len)

    # Draw question samples ("fakes") conditioned on image and answer.
    noise_vec, pathes, scores = sampler.random_sampling(
        [images, ans, ans_len], sess)
    batch_size = images.shape[0]
    scores, pathes, noise = post_process_variation_questions_noise(
        scores, pathes, noise_vec, batch_size)

    # Wrap fakes + reals for the LM; a CNN-based LM additionally needs
    # fixed-length (padded/truncated) inputs.
    wrap_kwargs = {'sampled': pathes,
                   'pad_token': sampler.pad_token - 1,
                   'gts': [quest, quest_len]}
    if 'CNN' in model.name:
        wrap_kwargs['max_length'] = 20
    lm_inputs = wrap_samples_for_language_model(**wrap_kwargs)

    loss, gstep = sess.run([train_step, global_step],
                           feed_dict=model.fill_feed_dict(lm_inputs))

    # Periodically print scored fake/real examples for inspection.
    if gstep % 500 == 0:

        def _print_scored(arr, arr_len, tag):
            rewards = model.inference([arr, arr_len])
            parsed = _parse_gt_questions(arr, arr_len)
            print('\n%s:' % (tag))
            for seq, reward in zip(parsed, rewards):
                if seq[-1] == 2:  # strip trailing end-of-sentence token
                    seq = seq[:-1]
                sent = _SENT.index_to_question(seq)
                print('%s (%0.3f)' % (sent, reward))

        fake, fake_len, real, real_len = lm_inputs
        _print_scored(fake, fake_len, 'Fake')
        _print_scored(real, real_len, 'Real')
    return loss, gstep
 def inference(self, ivqa_pathes):
     """Score candidate question paths for linguistic legality.

     Paths that match a ground-truth question (per the exemplar model)
     are forced to the maximum score of 1.0.
     """
     lm_in = wrap_samples_for_language_model([ivqa_pathes],
                                             pad_token=self.pad_token - 1,
                                             max_length=20)
     # Exact ground-truth matches are always considered legal.
     gt_hits = self.eg_lm.query(ivqa_pathes, self.min_gt_count)
     scores = self.nn_lm.inference(lm_in)
     scores[gt_hits] = 1.0
     return scores
# Example #3
def ivqa_decoding_beam_search(checkpoint_path=None):
    """Generate questions on the 'kptest' split, filter by language model,
    verify with a VQA model, and save results.

    Pipeline per test item: sample candidate questions from the iVQA
    model, keep unique ones, score their legality with a language model
    (ground-truth matches forced to 1.0), keep the top candidates, then
    score them with a VQA model and record the valid ones.

    Args:
        checkpoint_path: explicit checkpoint to restore; when None, the
            latest checkpoint under FLAGS.checkpoint_dir (or the pattern
            FLAGS.checkpoint_pat % (version, model_type)) is used.

    Returns:
        (res_file, mean_vqa_score): path of the saved JSON results and
        the mean of the per-batch VQA scores.
    """
    model_config = ModelConfig()
    method = FLAGS.method
    res_file = 'result/bs_gen_%s.json' % method
    score_file = 'result/bs_vqa_scores_%s.mat' % method
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-VVIS', phase='test')

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')

    # get data reader
    subset = 'kptest'
    reader = create_fn(batch_size=1, subset=subset, version=FLAGS.test_version)

    exemplar = ExemplarLanguageModel()

    # Resolve the checkpoint to restore when none is given explicitly.
    if checkpoint_path is None:
        if FLAGS.checkpoint_dir:
            ckpt_dir = FLAGS.checkpoint_dir
        else:
            ckpt_dir = FLAGS.checkpoint_pat % (FLAGS.version, FLAGS.model_type)
        # ckpt_dir = '/import/vision-ephemeral/fl302/models/v2_kpvaq_VAQ-RL/'
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path

    # Build model
    g = tf.Graph()
    with g.as_default():
        # Build the sampling model.
        model = model_fn(model_config, 'sampling')
        model.set_num_sampling_points(1000)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

        # build language model
        language_model = LanguageModel()
        language_model.build()
        language_model.set_cache_dir('test_empty')
        # language_model.set_cache_dir('v1_var_att_lowthresh_cache_restval_VAQ-VarRL')
        language_model.set_session(sess)
        language_model.setup_model()

        # build VQA model
        vqa_model = VQAWrapper(g, sess)
    # vqa_model = MLBWrapper()
    num_batches = reader.num_batches

    print('Running beam search inference...')
    results = []
    batch_vqa_scores = []

    num = FLAGS.max_iters if FLAGS.max_iters > 0 else num_batches
    for i in range(num):

        outputs = reader.get_test_batch()

        # inference
        quest_ids, image_ids = outputs[-2:]
        im, _, _, top_ans, ans_tokens, ans_len = outputs[:-2]
        # pdb.set_trace()
        # Skip items whose answer is outside the top-2000 answer
        # vocabulary (2000 presumably is the OOV label — TODO confirm).
        if top_ans == 2000:
            continue

        print('\n%d/%d' % (i, num))
        question_id = int(quest_ids[0])
        image_id = int(image_ids[0])

        t1 = time()
        pathes, scores = model.greedy_inference([im, ans_tokens, ans_len],
                                                sess)

        # find unique
        ivqa_scores, ivqa_pathes = process_one(scores, pathes)
        t2 = time()
        print('Time for sample generation: %0.2fs' % (t2 - t1))

        # apply language model
        language_model_inputs = wrap_samples_for_language_model(
            [ivqa_pathes], pad_token=model.pad_token - 1, max_length=20)
        match_gt = exemplar.query(ivqa_pathes)
        legality_scores = language_model.inference(language_model_inputs)
        # Exact ground-truth matches are always considered legal.
        legality_scores[match_gt] = 1.0
        num_keep = max(100, (legality_scores > 0.1).sum())  # no less than 100
        valid_inds = (-legality_scores).argsort()[:num_keep]

        t3 = time()
        print('Time for language model filtration: %0.2fs' % (t3 - t2))

        # for idx in valid_inds:
        #     path = ivqa_pathes[idx]
        #     sc = legality_scores[idx]
        #     sentence = to_sentence.index_to_question(path)
        #     # questions.append(sentence)
        #     print('%s (%0.3f)' % (sentence, sc))

        # apply VQA model
        sampled = [ivqa_pathes[_idx] for _idx in valid_inds]
        # vqa_scores = vqa_model.get_scores(sampled, image_id, top_ans)
        vqa_scores, is_valid = vqa_model.get_scores(sampled, im, top_ans)
        # conf_inds = (-vqa_scores).argsort()[:20]
        conf_inds = np.where(is_valid)[0]
        # pdb.set_trace()
        # conf_inds = (-vqa_scores).argsort()[:40]

        t4 = time()
        print('Time for VQA verification: %0.2fs' % (t4 - t3))

        # NOTE(review): if conf_inds is empty this mean() yields NaN.
        this_mean_vqa_score = vqa_scores[conf_inds].mean()
        print('sampled: %d, unique: %d, legal: %d, gt: %d, mean score %0.2f' %
              (pathes.shape[0], len(ivqa_pathes), num_keep, match_gt.sum(),
               this_mean_vqa_score))
        batch_vqa_scores.append(this_mean_vqa_score)

        for _pid, idx in enumerate(conf_inds):
            path = sampled[idx]
            sc = vqa_scores[idx]
            sentence = to_sentence.index_to_question(path)
            # Derive a unique id per augmented question from the
            # original question id.
            aug_quest_id = question_id * 1000 + _pid
            res_i = {
                'image_id': int(image_id),
                'question_id': aug_quest_id,
                'question': sentence,
                'score': float(sc)
            }
            results.append(res_i)

    save_json(res_file, results)
    # NOTE(review): batch_vqa_scores may be empty if every item was
    # skipped; .mean() would then return NaN.
    batch_vqa_scores = np.array(batch_vqa_scores, dtype=np.float32)
    mean_vqa_score = batch_vqa_scores.mean()
    from scipy.io import savemat
    savemat(score_file, {
        'scores': batch_vqa_scores,
        'mean_score': mean_vqa_score
    })
    print('BS mean VQA score: %0.3f' % mean_vqa_score)
    return res_file, mean_vqa_score
def ivqa_decoding_beam_search(checkpoint_path=None):
    """Generate question candidates for a fixed set of visualization items.

    Like the main beam-search routine, but restricted to the question ids
    in `quest_ids_to_vis` on the 'kpval' split, without VQA verification.
    Saves, for each kept candidate, the generated question alongside the
    ground-truth question and answer.

    Args:
        checkpoint_path: explicit checkpoint to restore; when None, the
            latest checkpoint under FLAGS.checkpoint_dir (or the pattern
            FLAGS.checkpoint_pat) is used.

    Returns:
        None. Results are written to 'result/bs_cand_for_vis.json'.
    """
    model_config = ModelConfig()
    method = FLAGS.method
    res_file = 'result/bs_cand_for_vis.json'
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-VVIS', phase='test')

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval',
                                    top_ans_file='../VQA-tensorflow/data/vqa_trainval_top2000_answers.txt')

    # get data reader
    subset = 'kpval'
    reader = create_fn(batch_size=1, subset=subset,
                       version=FLAGS.test_version)

    exemplar = ExemplarLanguageModel()

    # Resolve the checkpoint to restore when none is given explicitly.
    if checkpoint_path is None:
        if FLAGS.checkpoint_dir:
            ckpt_dir = FLAGS.checkpoint_dir
        else:
            ckpt_dir = FLAGS.checkpoint_pat % (FLAGS.version, FLAGS.model_type)
        # ckpt_dir = '/import/vision-ephemeral/fl302/models/v2_kpvaq_VAQ-RL/'
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path

    # Build model
    g = tf.Graph()
    with g.as_default():
        # Build the sampling model.
        model = model_fn(model_config, 'sampling')
        model.set_num_sampling_points(5000)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

        # build language model
        language_model = LanguageModel()
        language_model.build()
        language_model.set_cache_dir('test_empty')
        # language_model.set_cache_dir('v1_var_att_lowthresh_cache_restval_VAQ-VarRL')
        language_model.set_session(sess)
        language_model.setup_model()

        # build VQA model
    # vqa_model = N2MNWrapper()
    # vqa_model = MLBWrapper()
    num_batches = reader.num_batches

    # Only these question ids are processed (hand-picked examples).
    quest_ids_to_vis = {5682052: 'bread',
                        965492: 'plane',
                        681282: 'station'}

    print('Running beam search inference...')
    results = []
    batch_vqa_scores = []

    num = FLAGS.max_iters if FLAGS.max_iters > 0 else num_batches
    for i in range(num):

        outputs = reader.get_test_batch()

        # inference
        quest_ids, image_ids = outputs[-2:]
        quest_id_key = int(quest_ids)

        if quest_id_key not in quest_ids_to_vis:
            continue
        # pdb.set_trace()

        im, gt_q, _, top_ans, ans_tokens, ans_len = outputs[:-2]
        # pdb.set_trace()
        # Skip items whose answer is outside the top-2000 answer
        # vocabulary (2000 presumably is the OOV label — TODO confirm).
        if top_ans == 2000:
            continue

        print('\n%d/%d' % (i, num))
        question_id = int(quest_ids[0])
        image_id = int(image_ids[0])

        t1 = time()
        pathes, scores = model.greedy_inference([im, ans_tokens, ans_len], sess)

        # find unique
        ivqa_scores, ivqa_pathes = process_one(scores, pathes)
        t2 = time()
        print('Time for sample generation: %0.2fs' % (t2 - t1))

        # apply language model
        language_model_inputs = wrap_samples_for_language_model([ivqa_pathes],
                                                                pad_token=model.pad_token - 1,
                                                                max_length=20)
        match_gt = exemplar.query(ivqa_pathes)
        legality_scores = language_model.inference(language_model_inputs)
        # Exact ground-truth matches are always considered legal.
        legality_scores[match_gt] = 1.0
        num_keep = max(100, (legality_scores > 0.1).sum())  # no less than 100
        valid_inds = (-legality_scores).argsort()[:num_keep]
        print('keep: %d/%d' % (num_keep, len(ivqa_pathes)))

        t3 = time()
        print('Time for language model filtration: %0.2fs' % (t3 - t2))

        def token_arr_to_list(arr):
            # Flatten a (1, T) token array into a plain Python list.
            return arr.flatten().tolist()

        for _pid, idx in enumerate(valid_inds):
            path = ivqa_pathes[idx]
            # sc = vqa_scores[idx]
            sentence = to_sentence.index_to_question(path)
            # Derive a unique id per augmented question.
            aug_quest_id = question_id * 1000 + _pid
            res_i = {'image_id': int(image_id),
                     'aug_id': aug_quest_id,
                     'question_id': question_id,
                     'target': sentence,
                     'top_ans_id': int(top_ans),
                     'question': to_sentence.index_to_question(token_arr_to_list(gt_q)),
                     'answer': to_sentence.index_to_answer(token_arr_to_list(ans_tokens))}
            results.append(res_i)

    save_json(res_file, results)
    return None
# Example #5
def reinforce_trainstep(reader_outputs, model, env, sess, task_ops,
                        _VQA_Belief):
    """Run one REINFORCE update of the question generator.

    Samples questions from `model`, computes rewards through `env`,
    prepares the reinforce batch, runs `task_ops`, and records scored
    paths into `_VQA_Belief`. The inline language-model update branch is
    currently disabled (`if False:` below).

    Returns:
        The session outputs of `task_ops` with [avg_reward, 'reward']
        appended.
    """
    # reader_outputs = reader.pop_batch()
    # quest_ids, images, quest, quest_len, top_ans, ans, ans_len = reader_outputs
    # select the first image
    # idx = 0
    #
    # def _reshape_array(v):
    #     if type(v) == np.ndarray:
    #         return v[np.newaxis, :]
    #     else:
    #         return np.reshape(v, (1,))
    #
    # selected = [_reshape_array(v[idx]) for v in reader_outputs]
    res5c, images, quest, quest_len, top_ans, ans, ans_len, quest_ids, image_ids = reader_outputs
    # random sampling
    noise_vec, pathes, scores = model.random_sampling([images, ans, ans_len],
                                                      sess)
    _this_batch_size = images.shape[0]
    scores, pathes, noise = post_process_variation_questions_noise(
        scores, pathes, noise_vec, _this_batch_size, find_unique=False)

    # Wrap sampled (fake) and ground-truth (real) questions for the LM.
    lm_inputs = wrap_samples_for_language_model(sampled=pathes,
                                                pad_token=model.pad_token - 1,
                                                gts=[quest, quest_len],
                                                max_length=20)

    def _show_examples(arr, arr_len, _rewards, name):
        # Debug helper: print decoded questions with their rewards.
        ps = _parse_gt_questions(arr, arr_len)
        print('\n%s:' % (name))
        for p, r in zip(ps, _rewards):
            if p[-1] == 2:  # strip trailing end-of-sentence token
                p = p[:-1]
            sent = env.to_sentence.index_to_question(p)
            print('%s (%d)' % (sent, r))

    # compute reward
    vqa_inputs = [images, res5c, ans, ans_len, top_ans]
    # lm_inputs = lm_inputs[:2]
    wrapped_sampled = lm_inputs[:2]
    rewards, rewards_all, is_gt, aug_data = env.get_reward(
        pathes, [quest, quest_len],
        [vqa_inputs, wrapped_sampled, scores, quest_ids])

    max_path_arr, max_path_len, max_noise, max_rewards = \
        prepare_reinforce_data(pathes, noise, rewards, pad_token=model.pad_token)

    # Columns of rewards_all: 0 = VQA score, 2 = language score
    # (presumably — TODO confirm against env.get_reward).
    vqa_scores = rewards_all[:, 0]
    language_scores = rewards_all[:, 2]
    # scores = vqa_scores * (language_scores > 0.5)
    # Zero out VQA scores of linguistically illegal questions.
    scores = vqa_scores * (language_scores > env.language_thresh)
    new_pathes = _parse_gt_questions(max_path_arr, max_path_len)
    _VQA_Belief.insert(new_pathes, scores)

    # _show_examples(max_path_arr, max_path_len, is_gt, 'Sampled')
    # pdb.set_trace()

    aug_images, aug_ans, aug_ans_len, is_in_vocab = aug_data
    sess_in = [
        aug_images, max_path_arr, max_path_len, aug_ans, aug_ans_len,
        max_noise, max_rewards, rewards_all
    ]
    sess_in = [_in[is_in_vocab] for _in in sess_in]  # remove oov
    avg_reward = max_rewards.mean()

    # train op
    sess_outputs = sess.run(task_ops, feed_dict=model.fill_feed_dict(sess_in))
    sess_outputs += [avg_reward, 'reward']

    # update language model
    # print('Number GT: %d' % is_gt.sum())
    # num_fake_in_batch = 80 - is_gt.sum()
    # NOTE(review): this LM-update branch is deliberately disabled.
    if False:  # at least half is generated
        wrapped_gt = _Q_CTX.get_gt_batch(*lm_inputs[2:])  # random sample new
        corrected_inputs = correct_language_model_inputs(
            wrapped_sampled + wrapped_gt, is_gt)
        # num_fake = corrected_inputs[0].shape[0]
        # num_real = corrected_inputs[2].shape[0]
        # print('Num positive: %d, num negative %d' % (num_real, num_fake))

        # _show_examples(corrected_inputs[0], corrected_inputs[1], np.zeros_like(corrected_inputs[1]), 'Fake')
        # _show_examples(corrected_inputs[2], corrected_inputs[3], np.zeros_like(corrected_inputs[3]), 'Real')
        # pdb.set_trace()

        if min(wrapped_sampled[1].size, wrapped_gt[1].size) > 0:
            env.lm.trainstep(corrected_inputs)
    # _VQA_Belief.vertify_vqa(env, vqa_inputs)
    return sess_outputs
# Example #6
def ivqa_decoding_beam_search(checkpoint_path=None):
    """Mine negative (illegal) question paths on the 'kptrain' split.

    Samples questions per training item, scores their legality with the
    language model, and collects serialized paths whose legality score
    is below 0.2 (ground-truth matches excluded by forcing their score
    to 1.0). Stops after gathering just over 100000 negatives and saves
    them to 'data/lm_init_neg_pathes.json'.

    Args:
        checkpoint_path: explicit checkpoint to restore; when None, the
            latest checkpoint under FLAGS.checkpoint_dir (or the pattern
            FLAGS.checkpoint_pat) is used.
    """
    model_config = ModelConfig()
    method = FLAGS.method
    res_file = 'result/bs_gen_%s.json' % method
    score_file = 'result/bs_vqa_scores_%s.mat' % method
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-VVIS', phase='test')

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')

    # get data reader
    subset = 'kptrain'
    reader = create_fn(batch_size=1, subset=subset, version=FLAGS.test_version)

    exemplar = ExemplarLanguageModel()

    # Resolve the checkpoint to restore when none is given explicitly.
    if checkpoint_path is None:
        if FLAGS.checkpoint_dir:
            ckpt_dir = FLAGS.checkpoint_dir
        else:
            ckpt_dir = FLAGS.checkpoint_pat % (FLAGS.version, FLAGS.model_type)
        # ckpt_dir = '/import/vision-ephemeral/fl302/models/v2_kpvaq_VAQ-RL/'
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        checkpoint_path = ckpt.model_checkpoint_path

    # Build model
    g = tf.Graph()
    with g.as_default():
        # Build the sampling model.
        model = model_fn(model_config, 'sampling')
        model.set_num_sampling_points(5)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

        # build language model
        language_model = LanguageModel()
        language_model.build()
        language_model.set_cache_dir('test_empty')
        # language_model.set_cache_dir('v1_var_att_lowthresh_cache_restval_VAQ-VarRL')
        language_model.set_session(sess)
        language_model.setup_model()

    num_batches = reader.num_batches

    print('Running beam search inference...')

    num = FLAGS.max_iters if FLAGS.max_iters > 0 else num_batches
    neg_pathes = []
    need_stop = False
    for i in range(num):

        outputs = reader.get_test_batch()

        # inference
        im, _, _, top_ans, ans_tokens, ans_len = outputs[:-2]
        # Skip items whose answer is outside the top-2000 answer
        # vocabulary (2000 presumably is the OOV label — TODO confirm).
        if top_ans == 2000:
            continue

        print('\n%d/%d' % (i, num))

        t1 = time()
        pathes, scores = model.greedy_inference([im, ans_tokens, ans_len],
                                                sess)

        # find unique
        ivqa_scores, ivqa_pathes = process_one(scores, pathes)
        t2 = time()
        print('Time for sample generation: %0.2fs' % (t2 - t1))

        # apply language model
        language_model_inputs = wrap_samples_for_language_model(
            [ivqa_pathes], pad_token=model.pad_token - 1, max_length=20)
        match_gt = exemplar.query(ivqa_pathes)
        legality_scores = language_model.inference(language_model_inputs)
        # Exact ground-truth matches are never treated as negatives.
        legality_scores[match_gt] = 1.0

        # Collect low-legality paths as negatives.
        neg_inds = np.where(legality_scores < 0.2)[0]
        for idx in neg_inds:
            # [1:] presumably drops the leading start token — TODO confirm.
            ser_neg = serialize_path(ivqa_pathes[idx][1:])
            neg_pathes.append(ser_neg)
            if len(neg_pathes) > 100000:
                need_stop = True
                break
            # if len(neg_pathes) > 1000:
            #     need_stop = True
            #     break
            # print('Neg size: %d' % len(neg_pathes))
        if need_stop:
            break
    sv_file = 'data/lm_init_neg_pathes.json'
    save_json(sv_file, neg_pathes)
def ivqa_decoding_beam_search(ckpt_dir, method):
    """Generate LM-filtered question candidates on the 'bs_test' split.

    Uses beam search (or random sampling, per the hard-coded `inf_type`)
    to propose questions, filters them with the language model, and
    saves the kept candidates with their legality scores, grouped by
    question id. Skips entirely if the result file already exists.

    Args:
        ckpt_dir: directory containing the model checkpoint to restore.
        method: tag used to name the output file.
    """
    model_config = ModelConfig()
    # Inference mode is hard-coded here; 'rand' selects random sampling.
    inf_type = 'beam'
    assert (inf_type in ['beam', 'rand'])
    # method = FLAGS.method
    if inf_type == 'rand':
        res_file = 'result/bs_RL2_cands_LM_%s.json' % method
    else:
        res_file = 'result/bs_RL2_cands_LM_%s_BEAM.json' % method
    if os.path.exists(res_file):
        print('File %s already exist, skipped' % res_file)
        return
    # score_file = 'result/bs_vqa_scores_%s.mat' % method
    # Get model
    model_fn = get_model_creation_fn('VAQ-Var')
    create_fn = create_reader('VAQ-VVIS', phase='test')

    # Create the vocabulary.
    to_sentence = SentenceGenerator(trainset='trainval')

    # get data reader
    subset = 'bs_test'
    reader = create_fn(batch_size=1, subset=subset,
                       version=FLAGS.test_version)

    exemplar = ExemplarLanguageModel()

    # if checkpoint_path is None:
    #     if FLAGS.checkpoint_dir:
    #         ckpt_dir = FLAGS.checkpoint_dir
    #     else:
    #         ckpt_dir = FLAGS.checkpoint_pat % (FLAGS.version, FLAGS.model_type)
    # ckpt_dir = '/import/vision-ephemeral/fl302/models/v2_kpvaq_VAQ-RL/'
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    checkpoint_path = ckpt.model_checkpoint_path

    # Build model
    g = tf.Graph()
    with g.as_default():
        # Build the sampling (or beam-search) model.
        if inf_type == 'rand':
            model = model_fn(model_config, 'sampling')
            model.set_num_sampling_points(1000)
        else:
            model = model_fn(model_config, 'sampling_beam')
            model.set_num_sampling_points(1000)
        model.build()
        # Restore from checkpoint
        restorer = Restorer(g)
        sess = tf.Session()
        restorer.restore(sess, checkpoint_path)

        # build language model
        language_model = LanguageModel()
        language_model.build()
        language_model.set_cache_dir('test_empty')
        # language_model.set_cache_dir('v1_var_att_lowthresh_cache_restval_VAQ-VarRL')
        language_model.set_session(sess)
        language_model.setup_model()

        # build VQA model
    # vqa_model = N2MNWrapper()
    # vqa_model = MLBWrapper()
    num_batches = reader.num_batches

    print('Running beam search inference...')
    results = {}
    # batch_vqa_scores = []

    num = FLAGS.max_iters if FLAGS.max_iters > 0 else num_batches
    for i in range(num):
        outputs = reader.get_test_batch()

        # inference
        quest_ids, image_ids = outputs[-2:]
        im, _, _, top_ans, ans_tokens, ans_len = outputs[:-2]
        # pdb.set_trace()
        # Skip items whose answer is outside the top-2000 answer
        # vocabulary (2000 presumably is the OOV label — TODO confirm).
        if top_ans == 2000:
            continue

        print('\n%d/%d' % (i, num))
        question_id = int(quest_ids[0])
        image_id = int(image_ids[0])

        t1 = time()
        pathes, scores = model.greedy_inference([im, ans_tokens, ans_len], sess)

        # find unique
        ivqa_scores, ivqa_pathes = process_one(scores, pathes)
        t2 = time()
        print('Time for sample generation: %0.2fs' % (t2 - t1))

        # apply language model
        language_model_inputs = wrap_samples_for_language_model([ivqa_pathes],
                                                                pad_token=model.pad_token - 1,
                                                                max_length=20)
        match_gt = exemplar.query(ivqa_pathes)
        legality_scores = language_model.inference(language_model_inputs)
        # Exact ground-truth matches are always considered legal.
        legality_scores[match_gt] = 1.0
        num_keep = max(100, (legality_scores > 0.3).sum())  # no less than 100
        valid_inds = (-legality_scores).argsort()[:num_keep]

        t3 = time()
        print('Time for language model filtration: %0.2fs' % (t3 - t2))

        # for idx in valid_inds:
        #     path = ivqa_pathes[idx]
        #     sc = legality_scores[idx]
        #     sentence = to_sentence.index_to_question(path)
        #     # questions.append(sentence)
        #     print('%s (%0.3f)' % (sentence, sc))

        # Keep the surviving candidates and their legality scores.
        sampled = [ivqa_pathes[_idx] for _idx in valid_inds]
        legality_scores = legality_scores[valid_inds]

        result_key = int(question_id)
        tmp = []
        for idx, path in enumerate(sampled):
            # path = sampled[idx]
            sc = legality_scores[idx]
            sentence = to_sentence.index_to_question(path)
            # aug_quest_id = question_id * 1000 + _pid
            res_i = {'image_id': int(image_id),
                     'aug_id': idx,
                     'question_id': question_id,
                     'question': sentence,
                     'score': float(sc)}
            tmp.append(res_i)
        print('Number of unique questions: %d' % len(tmp))
        results[result_key] = tmp

    save_json(res_file, results)
# Example #8
def reinforce_trainstep(reader, model, env, sess, task_ops):
    """Run one REINFORCE update of the question generator, then update
    the language model with the sampled (fake) vs ground-truth (real)
    questions.

    Returns:
        The session outputs of `task_ops` with [avg_reward, 'reward']
        appended.
    """
    outputs = reader.pop_batch()
    images, quest, quest_len, ans, ans_len = outputs
    # random sampling
    noise_vec, pathes, scores = model.random_sampling([images, ans, ans_len],
                                                      sess)
    _this_batch_size = images.shape[0]
    scores, pathes, noise = post_process_variation_questions_noise(
        scores, pathes, noise_vec, _this_batch_size, find_unique=False)
    # diverse_rewards = env.diversity_reward.get_reward(pathes, scores)
    # update language model
    # lm = env.lm
    # fake = wrap_samples_for_language_model(pathes)
    # real = [quest, quest_len]
    # lm_inputs = fake + real
    # Wrap sampled (fake) and ground-truth (real) questions for the LM.
    lm_inputs = wrap_samples_for_language_model(sampled=pathes,
                                                pad_token=model.pad_token - 1,
                                                gts=[quest, quest_len],
                                                max_length=20)

    def _show_examples(arr, arr_len, _rewards, name):
        # Debug helper: print decoded questions with their rewards.
        ps = _parse_gt_questions(arr, arr_len)
        print('\n%s:' % (name))
        for p, r in zip(ps, _rewards):
            if p[-1] == 2:  # strip trailing end-of-sentence token
                p = p[:-1]
            sent = env.to_sentence.index_to_question(p)
            print('%s (%d)' % (sent, r))

    # compute reward
    vqa_inputs = [images, ans, ans_len]
    # lm_inputs = lm_inputs[:2]
    wrapped_sampled = lm_inputs[:2]
    rewards, rewards_all, is_gt, aug_data = env.get_reward(
        pathes, [quest, quest_len], [vqa_inputs, wrapped_sampled, scores])

    max_path_arr, max_path_len, max_noise, max_rewards = \
        prepare_reinforce_data(pathes, noise, rewards, pad_token=model.pad_token)

    # _show_examples(max_path_arr, max_path_len, is_gt, 'Sampled')
    # pdb.set_trace()

    aug_images, aug_ans, aug_ans_len, is_in_vocab = aug_data
    sess_in = [
        aug_images, max_path_arr, max_path_len, aug_ans, aug_ans_len,
        max_noise, max_rewards, rewards_all
    ]
    sess_in = [_in[is_in_vocab] for _in in sess_in]  # remove oov
    avg_reward = max_rewards.mean()

    # train op
    sess_outputs = sess.run(task_ops, feed_dict=model.fill_feed_dict(sess_in))
    sess_outputs += [avg_reward, 'reward']

    # update language model
    # print('Number GT: %d' % is_gt.sum())
    num_fake_in_batch = 80 - is_gt.sum()
    # NOTE(review): '... or True' makes this condition always true, so
    # the LM update always runs — presumably intentional during tuning.
    if num_fake_in_batch > 50 or True:  # at least half is generated
        wrapped_gt = _Q_CTX.get_gt_batch(*lm_inputs[2:])  # random sample new
        corrected_inputs = correct_language_model_inputs(
            wrapped_sampled + wrapped_gt, is_gt)
        # num_fake = corrected_inputs[0].shape[0]
        # num_real = corrected_inputs[2].shape[0]
        # print('Num positive: %d, num negative %d' % (num_real, num_fake))

        # _show_examples(corrected_inputs[0], corrected_inputs[1], np.zeros_like(corrected_inputs[1]), 'Fake')
        # _show_examples(corrected_inputs[2], corrected_inputs[3], np.zeros_like(corrected_inputs[3]), 'Real')
        # pdb.set_trace()
        if num_fake_in_batch > 0:
            env.lm.trainstep(corrected_inputs)
    return sess_outputs
def reinforce_trainstep(reader, model, env, sess, task_ops):
    """Run one REINFORCE update of the question generator, then jointly
    update the VQA model (on ground-truth plus augmented samples) and
    the language model.

    Returns:
        The session outputs of `task_ops` with [avg_reward, 'reward']
        appended.
    """
    outputs = reader.pop_batch()
    quest_ids, images, quest, quest_len, top_ans, ans, ans_len = outputs
    # random sampling
    noise_vec, pathes, scores = model.random_sampling([images, ans, ans_len], sess)
    _this_batch_size = images.shape[0]
    scores, pathes, noise = post_process_variation_questions_noise(scores,
                                                                   pathes,
                                                                   noise_vec,
                                                                   _this_batch_size,
                                                                   find_unique=False)

    # Wrap sampled (fake) and ground-truth (real) questions for the LM.
    lm_inputs = wrap_samples_for_language_model(sampled=pathes,
                                                pad_token=model.pad_token - 1,
                                                gts=[quest, quest_len],
                                                max_length=20)

    def _show_examples(arr, arr_len, _rewards, name):
        # Debug helper: print decoded questions with their rewards.
        ps = _parse_gt_questions(arr, arr_len)
        print('\n%s:' % (name))
        for p, r in zip(ps, _rewards):
            if p[-1] == 2:  # strip trailing end-of-sentence token
                p = p[:-1]
            sent = env.to_sentence.index_to_question(p)
            print('%s (%d)' % (sent, r))

    # compute reward
    vqa_inputs = [images, ans, ans_len, top_ans]
    # lm_inputs = lm_inputs[:2]
    wrapped_sampled = lm_inputs[:2]
    rewards, rewards_all, is_gt, aug_data = env.get_reward(pathes, [quest, quest_len],
                                                           [vqa_inputs, wrapped_sampled,
                                                            scores, quest_ids])

    max_path_arr, max_path_len, max_noise, max_rewards = \
        prepare_reinforce_data(pathes, noise, rewards, pad_token=model.pad_token)

    # _show_examples(max_path_arr, max_path_len, is_gt, 'Sampled')
    # pdb.set_trace()

    aug_images, aug_quest, aug_quest_len, aug_ans, aug_ans_len, aug_top_ans, is_in_vocab = aug_data
    sess_in = [aug_images, max_path_arr, max_path_len, aug_ans, aug_ans_len,
               max_noise, max_rewards, rewards_all]
    sess_in = [_in[is_in_vocab] for _in in sess_in]  # remove oov
    avg_reward = max_rewards.mean()

    # train op
    sess_outputs = sess.run(task_ops, feed_dict=model.fill_feed_dict(sess_in))
    sess_outputs += [avg_reward, 'reward']

    # update VQA model
    aug_legal_mask, aug_vqa_labels, hard_target_mask = correct_vqa_labels(aug_top_ans, rewards_all, is_in_vocab)
    # 2000 presumably is the OOV answer label — TODO confirm.
    gt_legal_mask = top_ans != 2000
    gt_hard_target_mask = np.ones_like(top_ans, dtype=np.float32)
    gt_inputs = [images, quest, quest_len, top_ans, gt_hard_target_mask, gt_legal_mask]
    aug_inputs = [aug_images, aug_quest, aug_quest_len, aug_vqa_labels, hard_target_mask, aug_legal_mask]
    vqa_inputs = concat_vqa_batch(gt_inputs, aug_inputs)
    vqa_is_valid = vqa_inputs[-1]
    vqa_inputs = [_in[vqa_is_valid] for _in in vqa_inputs[:-1]]  # remove invalid to save computation

    # print('Before VQA')
    # pdb.set_trace()
    vqa = env.get_vqa_model()
    vqa.trainstep(vqa_inputs)

    # update language model
    # print('Number GT: %d' % is_gt.sum())
    num_fake_in_batch = 80 - is_gt.sum()
    # NOTE(review): '... or True' makes this condition always true, so
    # the LM update always runs — presumably intentional during tuning.
    if num_fake_in_batch > 50 or True:  # at least half is generated
        wrapped_gt = _Q_CTX.get_gt_batch(*lm_inputs[2:])  # random sample new
        corrected_inputs = correct_language_model_inputs(wrapped_sampled + wrapped_gt, is_gt)
        if num_fake_in_batch > 0:
            env.lm.trainstep(corrected_inputs)
    return sess_outputs