Code Example #1
    def _run_test_repeat(self, tmpdir: str, fake_input: FakeInput):
        outfile = os.path.join(tmpdir, 'log.jsonl')
        Interactive.main(model='repeat_query', outfile=outfile)

        log = conversations.Conversations(outfile)
        self.assertEqual(len(log), fake_input.max_episodes)
        for entry in log:
            self.assertEqual(len(entry), 2 * fake_input.max_turns)
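The test above writes the interactive session to `outfile` and reads it back through the Conversations helper. As a standalone illustration of that read-back step, here is a minimal sketch, assuming the helper lives at parlai.utils.conversations (as in ParlAI's own tests); the log path is illustrative, not from the snippet:

    # Hedged sketch: reading a saved interactive log outside of a test.
    # The import path and the log file path are assumptions for illustration.
    from parlai.utils import conversations

    log = conversations.Conversations('/tmp/log.jsonl')  # illustrative path
    for entry in log:
        # each entry is one logged conversation; its length is the turn count
        print(len(entry), 'turns')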
Code Example #2
def test_constrained_beam_search(self):
    # call it with particular args
    Interactive.main(
        model='transformer/generator',
        task='blended_skill_talk',
        include_personas=False,
        include_initial_utterances=False,
        # history_size=2,
        beam_size=4,
        beam_min_length=20,
        beam_context_block_ngram=0,
        beam_block_ngram=3,
        inference='constrainedbeam',
        # inference='beam',
        model_parallel=True,
        # the model_file is a filename path pointing to a particular model dump.
        # Model files that begin with "zoo:" are special files distributed by
        # the ParlAI team. They'll be automatically downloaded when you ask to
        # use them.
        # model_file="zoo:blender/blender_3B/model"
        model_file="zoo:blender/blender_90M/model")
Code Example #3
                    with open(
                            '/raid/zhenduow/conversationalQA/data/' +
                            exp_name + '/retrieved_question',
                            'w') as qrerankerinput:
                        for rq in retrieved_question:
                            qrerankerinput.write(rq)
                            qrerankerinput.write('\n')
                    with open(
                            '/raid/zhenduow/conversationalQA/data/' +
                            exp_name + '/retrieved_answer',
                            'w') as arerankerinput:
                        for ra in retrieved_answer:
                            arerankerinput.write(ra)
                            arerankerinput.write('\n')

                    # get reranker results
                    question = Interactive.main(
                        model='transformer/polyencoder',
                        model_file='zoo:pretrained_transformers/model_poly/msdialogquestion',
                        encode_candidate_vecs='true',
                        eval_candidates='fixed',
                        fixed_candidates_path='/raid/zhenduow/conversationalQA/data/' + exp_name + '/retrieved_question',
                        force_fp16_tokens=True,
                        human_input=obs,
                        fixed_candidate_vecs='replace')
                    answer = Interactive.main(
                        model='transformer/polyencoder',
                        model_file='zoo:pretrained_transformers/model_poly/msdialoganswer',
                        encode_candidate_vecs='true',
                        eval_candidates='fixed',
                        fixed_candidates_path='/raid/zhenduow/conversationalQA/data/' + exp_name + '/retrieved_answer',
                        human_input=obs,
                        fixed_candidate_vecs='replace')

                    memory[obs] = [question, answer]

                action = agent.choose_action(obs, question, answer)
                obs_, question_reward, q_done, good_question = user.update_state(
                    conv_id, obs, 1, question, answer, use_top_k=use_top_k)
                _, answer_reward, _, _ = user.update_state(conv_id,
                                                           obs,
Code Example #4
def test_repeat(self):
    Interactive.main(model='repeat_query',
                     task='convai2',
                     datatype='valid')
Code Example #5
def test_repeat(self):
    Interactive.main(model='repeat_query')
Code Example #6
def main():
    model_path = '/home/warvisionary/parlai_transfer/from_pretrained/model'
    print(f"Interacting with model at: {model_path}")
    Interactive.main(task='empathetic_dialogues_ru', model_file=model_path)
Code Example #7
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Basic example which allows local human keyboard input to talk to a trained model.

For documentation, see parlai.scripts.interactive.
"""
import random
from parlai.scripts.interactive import Interactive

if __name__ == '__main__':
    random.seed(42)
    Interactive.main()
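Because main() is called here with no arguments, ParlAI parses the model and other options from the command line rather than from keyword arguments. A minimal programmatic counterpart is sketched below, borrowing the zoo model file from Code Example #8 purely as an illustration:

    # Hedged sketch: the same script driven from code instead of CLI flags.
    # The model file is taken from Code Example #8 for illustration only.
    Interactive.main(model_file='zoo:tutorial_transformer_generator/model')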
Code Example #8
# Import the Interactive script
from parlai.scripts.interactive import Interactive

# call it with particular args
Interactive.main(
    # the model_file is a filename path pointing to a particular model dump.
    # Model files that begin with "zoo:" are special files distributed by the ParlAI team.
    # They'll be automatically downloaded when you ask to use them.
    model_file='zoo:tutorial_transformer_generator/model')
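Generation-time flags like those in Code Example #2 can be passed in the same call as the zoo model file. A minimal sketch, with illustrative values borrowed from that example rather than from this snippet:

    # Hedged sketch: zoo model plus beam-search flags from Code Example #2.
    from parlai.scripts.interactive import Interactive

    Interactive.main(
        model_file='zoo:tutorial_transformer_generator/model',
        inference='beam',    # decoding strategy
        beam_size=4,         # illustrative values from Code Example #2
        beam_min_length=20,
        beam_block_ngram=3)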
Code Example #9
def main(args):
    logging.getLogger().setLevel(logging.INFO)
    limit_memory(1e11)

    random.seed(2020)
    if args.cv != -1:
        train_dataset = ConversationDataset(
            'data/' + args.dataset_name + '-Complete/train' + str(args.cv) +
            '/', batch_size, max_train_size)
        test_dataset = ConversationDataset(
            'data/' + args.dataset_name + '-Complete/test' + str(args.cv) +
            '/', batch_size, max_test_size)
    else:
        train_dataset = ConversationDataset(
            'data/' + args.dataset_name + '-Complete/train/', batch_size,
            max_train_size)
        test_dataset = ConversationDataset(
            'data/' + args.dataset_name + '-Complete/test/', batch_size,
            max_test_size)
    train_size = sum(
        [len(b['conversations'].keys()) for b in train_dataset.batches])
    test_size = sum(
        [len(b['conversations'].keys()) for b in test_dataset.batches])
    agent = Agent(lr=1e-4,
                  input_dims=(3 + args.topn) * observation_dim + 1 + args.topn,
                  top_k=args.topn,
                  n_actions=action_num,
                  gamma=agent_gamma,
                  weight_decay=0.01)
    score_agent = ScoreAgent(lr=1e-4,
                             input_dims=1 + args.topn,
                             top_k=args.topn,
                             n_actions=action_num,
                             gamma=agent_gamma,
                             weight_decay=0.0)
    text_agent = TextAgent(lr=1e-4,
                           input_dims=(3 + args.topn) * observation_dim,
                           top_k=args.topn,
                           n_actions=action_num,
                           gamma=agent_gamma,
                           weight_decay=0.01)
    base_agent = BaseAgent(lr=1e-4,
                           input_dims=2 * observation_dim,
                           n_actions=2,
                           weight_decay=0.01)

    if args.dataset_name == 'MSDialog':
        reranker_prefix = ''
    elif args.dataset_name == 'UDC':
        reranker_prefix = 'udc'
    elif args.dataset_name == 'Opendialkg':
        reranker_prefix = 'open'
    # create rerankers
    if args.reranker_name == 'Poly':
        question_reranker = Interactive.main(
            model='transformer/polyencoder',
            model_file='zoo:pretrained_transformers/model_poly/' + reranker_prefix + 'question',
            encode_candidate_vecs=False,
            eval_candidates='inline',
            interactive_candidates='inline',
            return_cand_scores=True)
        answer_reranker = Interactive.main(
            model='transformer/polyencoder',
            model_file='zoo:pretrained_transformers/model_poly/' + reranker_prefix + 'answer',
            encode_candidate_vecs=False,
            eval_candidates='inline',
            interactive_candidates='inline',
            return_cand_scores=True)
        print("Loading rerankers:", 'model_poly/' + reranker_prefix + 'answer',
              'model_poly/' + reranker_prefix + 'question')
    elif args.reranker_name == 'Bi':
        bi_question_reranker = Interactive.main(
            model='transformer/biencoder',
            model_file='zoo:pretrained_transformers/model_bi/' + reranker_prefix + 'question',
            encode_candidate_vecs=False,
            eval_candidates='inline',
            interactive_candidates='inline',
            return_cand_scores=True)
        bi_answer_reranker = Interactive.main(
            model='transformer/biencoder',
            model_file='zoo:pretrained_transformers/model_bi/' + reranker_prefix + 'answer',
            encode_candidate_vecs=False,
            eval_candidates='inline',
            interactive_candidates='inline',
            return_cand_scores=True)
        print("Loading rerankers:", 'model_bi/' + reranker_prefix + 'answer',
              'model_bi/' + reranker_prefix + 'question')

    # embedding model
    tokenizer = AutoTokenizer.from_pretrained('xlnet-base-cased')
    embedding_model = AutoModel.from_pretrained('xlnet-base-cased')
    '''
    local_vars = list(locals().items())
    for var, obj in local_vars:
        print(var, get_size(obj))
    '''

    if not os.path.exists(args.dataset_name + '_experiments/embedding_cache/'):
        os.makedirs(args.dataset_name + '_experiments/embedding_cache/')
    if not os.path.exists(args.dataset_name + '_experiments/embedding_cache/' +
                          args.reranker_name):
        os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                    args.reranker_name)
    if args.cv != -1:
        if not os.path.exists(args.dataset_name +
                              '_experiments/embedding_cache/' +
                              args.reranker_name + '/' + str(args.cv)):
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/' + str(args.cv))
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/' + str(args.cv) + '/train')
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/' + str(args.cv) + '/test')
    else:
        if not os.path.exists(args.dataset_name +
                              '_experiments/embedding_cache/' +
                              args.reranker_name + '/train'):
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/train')
        if not os.path.exists(args.dataset_name +
                              '_experiments/embedding_cache/' +
                              args.reranker_name + '/test'):
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/test')

    for i in range(train_iter):
        train_scores, train_q0_scores, train_q1_scores, train_q2_scores, train_oracle_scores, train_base_scores, train_score_scores, train_text_scores = [],[],[],[],[],[],[],[]
        train_worse, train_q0_worse, train_q1_worse, train_q2_worse, train_oracle_worse, train_base_worse, train_score_worse, train_text_worse = [],[],[],[],[],[],[],[]
        #train_correct, train_q0_correct, train_q1_correct, train_q2_correct, train_oracle_correct, train_base_correct, train_score_correct,train_text_correct = [],[],[],[],[],[],[],[]
        for batch_serial, batch in enumerate(train_dataset.batches):
            print(dict(psutil.virtual_memory()._asdict()))
            if args.cv != -1:
                if os.path.exists(args.dataset_name +
                                  '_experiments/embedding_cache/' +
                                  args.reranker_name + '/' + str(args.cv) +
                                  '/train/memory.batchsave' +
                                  str(batch_serial)):
                    with T.no_grad():
                        memory = T.load(args.dataset_name +
                                        '_experiments/embedding_cache/' +
                                        args.reranker_name + '/' +
                                        str(args.cv) +
                                        '/train/memory.batchsave' +
                                        str(batch_serial))
                else:
                    memory = {}
            else:
                if os.path.exists(args.dataset_name +
                                  '_experiments/embedding_cache/' +
                                  args.reranker_name +
                                  '/train/memory.batchsave' +
                                  str(batch_serial)):
                    with T.no_grad():
                        memory = T.load(args.dataset_name +
                                        '_experiments/embedding_cache/' +
                                        args.reranker_name +
                                        '/train/memory.batchsave' +
                                        str(batch_serial))
                else:
                    memory = {}
            train_ids = list(batch['conversations'].keys())
            user = User(batch['conversations'],
                        cq_reward=cq_reward,
                        cq_penalty=cq_penalty)
            for conv_serial, train_id in enumerate(train_ids):
                query = user.initialize_state(train_id)
                if query == '':  # UDC dataset has some weird stuff
                    continue
                context = ''
                ignore_questions = []
                n_round = 0
                patience_used = 0
                q_done = False
                stop, base_stop, score_stop, text_stop = False, False, False, False
                print(
                    '-------- train batch %.0f conversation %.0f/%.0f --------'
                    % (batch_serial, batch_size *
                       (batch_serial) + conv_serial + 1, train_size))

                while not q_done:
                    total_tic = time.perf_counter()
                    print('-------- round %.0f --------' % (n_round))
                    if query in memory.keys():
                        if context not in memory[query].keys():
                            # sampling
                            question_candidates = generate_batch_question_candidates(
                                batch, train_id, ignore_questions, batch_size)
                            answer_candidates = generate_batch_answer_candidates(
                                batch, train_id, batch_size)
                            # get reranker results
                            if args.reranker_name == 'Poly':
                                questions, questions_scores = rerank(
                                    question_reranker, query, context,
                                    question_candidates)
                                answers, answers_scores = rerank(
                                    answer_reranker, query, context,
                                    answer_candidates)
                            elif args.reranker_name == 'Bi':
                                questions, questions_scores = rerank(
                                    bi_question_reranker, query, context,
                                    question_candidates)
                                answers, answers_scores = rerank(
                                    bi_answer_reranker, query, context,
                                    answer_candidates)

                            memory = save_to_memory(query, context, memory,
                                                    questions, answers,
                                                    questions_scores,
                                                    answers_scores, tokenizer,
                                                    embedding_model)

                    else:
                        # sampling
                        question_candidates = generate_batch_question_candidates(
                            batch, train_id, ignore_questions, batch_size)
                        answer_candidates = generate_batch_answer_candidates(
                            batch, train_id, batch_size)
                        # get reranker results
                        if args.reranker_name == 'Poly':
                            questions, questions_scores = rerank(
                                question_reranker, query, context,
                                question_candidates)
                            answers, answers_scores = rerank(
                                answer_reranker, query, context,
                                answer_candidates)
                        elif args.reranker_name == 'Bi':
                            questions, questions_scores = rerank(
                                bi_question_reranker, query, context,
                                question_candidates)
                            answers, answers_scores = rerank(
                                bi_answer_reranker, query, context,
                                answer_candidates)

                        memory = save_to_memory(query, context, memory,
                                                questions, answers,
                                                questions_scores,
                                                answers_scores, tokenizer,
                                                embedding_model)

                    query_embedding, context_embedding, questions, answers, questions_embeddings, answers_embeddings, questions_scores, answers_scores = read_from_memory(
                        query, context, memory)
                    action = agent.choose_action(query_embedding,
                                                 context_embedding,
                                                 questions_embeddings,
                                                 answers_embeddings,
                                                 questions_scores,
                                                 answers_scores)
                    base_action = base_agent.choose_action(
                        query_embedding, context_embedding)
                    score_action = score_agent.choose_action(
                        questions_scores, answers_scores)
                    text_action = text_agent.choose_action(
                        query_embedding, context_embedding,
                        questions_embeddings, answers_embeddings)

                    evaluation_tic = time.perf_counter()
                    #context_, question_reward, q_done, good_question, patience_this_turn = user.update_state(train_id, context, 1, questions, answers, use_top_k = args.topn - patience_used)
                    context_, question_reward, q_done, good_question, patience_this_turn = user.update_state(
                        train_id,
                        context,
                        1,
                        questions,
                        answers,
                        use_top_k=args.topn)
                    patience_used = max(patience_used + patience_this_turn,
                                        args.topn)
                    _, answer_reward, _, _, _ = user.update_state(
                        train_id,
                        context,
                        0,
                        questions,
                        answers,
                        use_top_k=args.topn - patience_used)
                    action_reward = [answer_reward, question_reward][action]
                    evaluation_toc = time.perf_counter()
                    print('action', action, 'base_action', base_action,
                          'score_action', score_action, 'text_action',
                          text_action, 'answer reward', answer_reward,
                          'question reward', question_reward, 'q done', q_done)

                    if n_round >= max_round:
                        q_done = True

                    if not q_done:
                        ignore_questions.append(good_question)
                        if context_ not in memory[query].keys():
                            # sampling
                            question_candidates = generate_batch_question_candidates(
                                batch, train_id, ignore_questions, batch_size)
                            answer_candidates = generate_batch_answer_candidates(
                                batch, train_id, batch_size)

                            # get reranker results
                            if args.reranker_name == 'Poly':
                                questions_, questions_scores_ = rerank(
                                    question_reranker, query, context_,
                                    question_candidates)
                                answers_, answers_scores_ = rerank(
                                    answer_reranker, query, context_,
                                    answer_candidates)
                            elif args.reranker_name == 'Bi':
                                questions_, questions_scores_ = rerank(
                                    bi_question_reranker, query, context_,
                                    question_candidates)
                                answers_, answers_scores_ = rerank(
                                    bi_answer_reranker, query, context_,
                                    answer_candidates)

                            memory = save_to_memory(query, context_, memory,
                                                    questions_, answers_,
                                                    questions_scores_,
                                                    answers_scores_, tokenizer,
                                                    embedding_model)
                        query_embedding, context_embedding_, questions_, answers_, questions_embeddings_, answers_embeddings_, questions_scores_, answers_scores_ = read_from_memory(
                            query, context_, memory)

                    else:
                        context_embedding_ = generate_embedding_no_grad(
                            context_, tokenizer, embedding_model)
                        questions_, answers_, questions_embeddings_, answers_embeddings_, questions_scores_, answers_scores_ = None, None, None, None, None, None

                    agent.joint_learn((query_embedding, context_embedding, questions_embeddings, answers_embeddings, questions_scores, answers_scores),\
                        answer_reward, question_reward,\
                        (query_embedding, context_embedding_, questions_embeddings_, answers_embeddings_, questions_scores_, answers_scores_))
                    base_agent.learn(
                        query_embedding, context_embedding, 0 if
                        (n_round +
                         1) == len(user.dataset[train_id]) / 2 else 1)
                    score_agent.joint_learn((questions_scores, answers_scores),\
                        answer_reward, question_reward,\
                        (questions_scores_, answers_scores_))
                    text_agent.joint_learn((query_embedding,context_embedding, questions_embeddings, answers_embeddings),\
                        answer_reward, question_reward,\
                        (query_embedding, context_embedding_, questions_embeddings_, answers_embeddings_))

                    # evaluation
                    if (action == 0 or
                        (action == 1
                         and question_reward == cq_penalty)) and not stop:
                        stop = True
                        train_scores.append(answer_reward if action ==
                                            0 else 0)
                        if action == 0 and answer_reward == 1.0:
                            #train_correct.append(train_id)
                            pass
                        train_worse.append(1 if (action == 0 and answer_reward < float(1/args.topn) and question_reward == cq_reward) \
                            or (action == 1  and question_reward == cq_penalty) else 0)

                    if (base_action == 0 or
                        (base_action == 1
                         and question_reward == cq_penalty)) and not base_stop:
                        base_stop = True
                        train_base_scores.append(
                            answer_reward if base_action == 0 else 0)
                        if base_action == 0 and answer_reward == 1.0:
                            #train_base_correct.append(train_id)
                            pass
                        train_base_worse.append(1 if (base_action == 0 and answer_reward < float(1/args.topn) and question_reward == cq_reward) \
                            or (base_action == 1  and question_reward == cq_penalty) else 0)

                    if (score_action == 0 or
                        (score_action == 1 and question_reward
                         == cq_penalty)) and not score_stop:
                        score_stop = True
                        train_score_scores.append(
                            answer_reward if score_action == 0 else 0)
                        if score_action == 0 and answer_reward == 1.0:
                            pass
                            #train_score_correct.append(train_id)
                        train_score_worse.append(1 if (score_action == 0 and answer_reward < float(1/args.topn) and question_reward == cq_reward) \
                            or (score_action == 1  and question_reward == cq_penalty) else 0)

                    if (text_action == 0 or
                        (text_action == 1
                         and question_reward == cq_penalty)) and not text_stop:
                        text_stop = True
                        train_text_scores.append(
                            answer_reward if text_action == 0 else 0)
                        if text_action == 0 and answer_reward == 1.0:
                            pass
                            #train_text_correct.append(train_id)
                        train_text_worse.append(1 if (text_action == 0 and answer_reward < float(1/args.topn) and question_reward == cq_reward) \
                            or (text_action == 1  and question_reward == cq_penalty) else 0)

                    if n_round == 0:
                        train_q0_scores.append(answer_reward)
                        train_q0_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass
                            #train_q0_correct.append(train_id)
                        if q_done:
                            train_q1_scores.append(0)
                            train_q2_scores.append(0)
                            train_q1_worse.append(1)
                            train_q2_worse.append(1)
                    elif n_round == 1:
                        train_q1_scores.append(answer_reward)
                        train_q1_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass
                            #train_q1_correct.append(train_id)
                        if q_done:
                            train_q2_scores.append(0)
                            train_q2_worse.append(1)
                    elif n_round == 2:
                        train_q2_scores.append(answer_reward)
                        train_q2_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass
                            #train_q2_correct.append(train_id)

                    context = context_
                    n_round += 1
                    total_toc = time.perf_counter()

            # save memory per batch
            if args.cv != -1:
                T.save(
                    memory,
                    args.dataset_name + '_experiments/embedding_cache/' +
                    args.reranker_name + '/' + str(args.cv) +
                    '/train/memory.batchsave' + str(batch_serial))
            else:
                T.save(
                    memory, args.dataset_name +
                    '_experiments/embedding_cache/' + args.reranker_name +
                    '/train/memory.batchsave' + str(batch_serial))

            del memory
            T.cuda.empty_cache()

        for oi in range(len(train_scores)):
            train_oracle_scores.append(
                max(train_q0_scores[oi], train_q1_scores[oi],
                    train_q2_scores[oi]))
            train_oracle_worse.append(
                min(train_q0_worse[oi], train_q1_worse[oi],
                    train_q2_worse[oi]))
        #train_oracle_correct = list(set(train_correct + train_q0_correct + train_q2_correct))

        print("Train epoch %.0f, acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (i, np.mean([1 if score == 1 else 0 for score in train_scores
                           ]), np.mean(train_scores), np.mean(train_worse)))
        print("q0 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_q0_scores
                        ]), np.mean(train_q0_scores), np.mean(train_q0_worse)))
        print("q1 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_q1_scores
                        ]), np.mean(train_q1_scores), np.mean(train_q1_worse)))
        print("q2 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_q2_scores
                        ]), np.mean(train_q2_scores), np.mean(train_q2_worse)))
        print(
            "oracle acc %.6f, avgmrr %.6f, worse decisions %.6f" %
            (np.mean([1 if score == 1 else 0
                      for score in train_oracle_scores]),
             np.mean(train_oracle_scores), np.mean(train_oracle_worse)))
        print("base acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0
                        for score in train_base_scores]),
               np.mean(train_base_scores), np.mean(train_base_worse)))
        print(
            "score acc %.6f, avgmrr %.6f, worse decisions %.6f" %
            (np.mean([1 if score == 1 else 0 for score in train_score_scores]),
             np.mean(train_score_scores), np.mean(train_score_worse)))
        print("text acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0
                        for score in train_text_scores]),
               np.mean(train_text_scores), np.mean(train_text_worse)))
        '''
        print(train_correct)
        print(train_q0_correct)
        print(train_q1_correct)
        print(train_q2_correct)
        print(train_oracle_correct)
        print(train_base_correct)
        print(train_score_correct)
        print(train_text_correct)
        '''
        print("avg loss", np.mean(agent.loss_history))

        ## test
        test_scores, test_q0_scores, test_q1_scores, test_q2_scores, test_oracle_scores, test_base_scores, test_score_scores, test_text_scores = [],[],[],[],[],[],[],[]
        test_worse, test_q0_worse, test_q1_worse,test_q2_worse, test_oracle_worse, test_base_worse, test_score_worse, test_text_worse = [],[],[],[],[],[],[],[]
        #test_correct, test_q0_correct, test_q1_correct, test_q2_correct, test_oracle_correct, test_base_correct, test_score_correct, test_text_correct = [],[],[],[],[],[],[],[]
        # test the agent
        agent.epsilon = 0

        for batch_serial, batch in enumerate(test_dataset.batches):
            if args.cv != -1:
                if os.path.exists(args.dataset_name +
                                  '_experiments/embedding_cache/' +
                                  args.reranker_name + '/' + str(args.cv) +
                                  '/test/memory.batchsave' +
                                  str(batch_serial)):
                    with T.no_grad():
                        memory = T.load(args.dataset_name +
                                        '_experiments/embedding_cache/' +
                                        args.reranker_name + '/' +
                                        str(args.cv) +
                                        '/test/memory.batchsave' +
                                        str(batch_serial))
                else:
                    memory = {}
            else:
                if os.path.exists(args.dataset_name +
                                  '_experiments/embedding_cache/' +
                                  args.reranker_name +
                                  '/test/memory.batchsave' +
                                  str(batch_serial)):
                    with T.no_grad():
                        memory = T.load(args.dataset_name +
                                        '_experiments/embedding_cache/' +
                                        args.reranker_name +
                                        '/test/memory.batchsave' +
                                        str(batch_serial))
                else:
                    memory = {}
            test_ids = list(batch['conversations'].keys())
            user = User(batch['conversations'],
                        cq_reward=cq_reward,
                        cq_penalty=cq_penalty)
            for conv_serial, test_id in enumerate(test_ids):
                query = user.initialize_state(test_id)
                if query == '':  # UDC dataset has some weird stuff
                    continue
                context = ''
                ignore_questions = []
                n_round = 0
                patience_used = 0
                q_done = False
                stop, base_stop, score_stop, text_stop = False, False, False, False
                print(
                    '-------- test batch %.0f conversation %.0f/%.0f --------'
                    % (batch_serial, batch_size *
                       (batch_serial) + conv_serial + 1, test_size))
                while not q_done:
                    print('-------- round %.0f --------' % (n_round))
                    if query in memory.keys():
                        if context not in memory[query].keys():
                            # sampling
                            question_candidates = generate_batch_question_candidates(
                                batch, test_id, ignore_questions, batch_size)
                            answer_candidates = generate_batch_answer_candidates(
                                batch, test_id, batch_size)
                            # get reranker results
                            if args.reranker_name == 'Poly':
                                questions, questions_scores = rerank(
                                    question_reranker, query, context,
                                    question_candidates)
                                answers, answers_scores = rerank(
                                    answer_reranker, query, context,
                                    answer_candidates)
                            elif args.reranker_name == 'Bi':
                                questions, questions_scores = rerank(
                                    bi_question_reranker, query, context,
                                    question_candidates)
                                answers, answers_scores = rerank(
                                    bi_answer_reranker, query, context,
                                    answer_candidates)

                            memory = save_to_memory(query, context, memory,
                                                    questions, answers,
                                                    questions_scores,
                                                    answers_scores, tokenizer,
                                                    embedding_model)

                    else:
                        # sampling
                        question_candidates = generate_batch_question_candidates(
                            batch, test_id, ignore_questions, batch_size)
                        answer_candidates = generate_batch_answer_candidates(
                            batch, test_id, batch_size)

                        # get reranker results
                        if args.reranker_name == 'Poly':
                            questions, questions_scores = rerank(
                                question_reranker, query, context,
                                question_candidates)
                            answers, answers_scores = rerank(
                                answer_reranker, query, context,
                                answer_candidates)
                        elif args.reranker_name == 'Bi':
                            questions, questions_scores = rerank(
                                bi_question_reranker, query, context,
                                question_candidates)
                            answers, answers_scores = rerank(
                                bi_answer_reranker, query, context,
                                answer_candidates)

                        memory = save_to_memory(query, context, memory,
                                                questions, answers,
                                                questions_scores,
                                                answers_scores, tokenizer,
                                                embedding_model)

                    query_embedding, context_embedding, questions, answers, questions_embeddings, answers_embeddings, questions_scores, answers_scores = read_from_memory(
                        query, context, memory)
                    action = agent.choose_action(query_embedding,
                                                 context_embedding,
                                                 questions_embeddings,
                                                 answers_embeddings,
                                                 questions_scores,
                                                 answers_scores)
                    base_action = base_agent.choose_action(
                        query_embedding, context_embedding)
                    score_action = score_agent.choose_action(
                        questions_scores, answers_scores)
                    text_action = text_agent.choose_action(
                        query_embedding, context_embedding,
                        questions_embeddings, answers_embeddings)

                    #context_, question_reward, q_done, good_question, patience_this_turn = user.update_state(test_id, context, 1, questions, answers, use_top_k = args.topn - patience_used)
                    context_, question_reward, q_done, good_question, patience_this_turn = user.update_state(
                        test_id,
                        context,
                        1,
                        questions,
                        answers,
                        use_top_k=args.topn)
                    patience_used = max(patience_used + patience_this_turn,
                                        args.topn)
                    _, answer_reward, _, _, _ = user.update_state(
                        test_id,
                        context,
                        0,
                        questions,
                        answers,
                        use_top_k=args.topn - patience_used)
                    action_reward = [answer_reward, question_reward][action]
                    print('action', action, 'base_action', base_action,
                          'score_action', score_action, 'text_action',
                          text_action, 'answer reward', answer_reward,
                          'question reward', question_reward, 'q done', q_done)

                    if n_round >= max_round:
                        q_done = True

                    if not q_done:
                        ignore_questions.append(good_question)
                        if context_ not in memory[query].keys():
                            # sampling
                            question_candidates = generate_batch_question_candidates(
                                batch, test_id, ignore_questions, batch_size)
                            answer_candidates = generate_batch_answer_candidates(
                                batch, test_id, batch_size)
                            # get reranker results
                            if args.reranker_name == 'Poly':
                                questions_, questions_scores_ = rerank(
                                    question_reranker, query, context_,
                                    question_candidates)
                                answers_, answers_scores_ = rerank(
                                    answer_reranker, query, context_,
                                    answer_candidates)
                            elif args.reranker_name == 'Bi':
                                questions_, questions_scores_ = rerank(
                                    bi_question_reranker, query, context_,
                                    question_candidates)
                                answers_, answers_scores_ = rerank(
                                    bi_answer_reranker, query, context_,
                                    answer_candidates)

                            memory = save_to_memory(query, context_, memory,
                                                    questions_, answers_,
                                                    questions_scores_,
                                                    answers_scores_, tokenizer,
                                                    embedding_model)
                        query_embedding, context_embedding_, questions_, answers_, questions_embeddings_, answers_embeddings_, questions_scores_, answers_scores_ = read_from_memory(
                            query, context_, memory)

                    # evaluation
                    if (action == 0 or
                        (action == 1
                         and question_reward == cq_penalty)) and not stop:
                        stop = True
                        test_scores.append(answer_reward if action == 0 else 0)
                        if action == 0 and answer_reward == 1.0:
                            pass
                            #test_correct.append(test_id)
                        test_worse.append(1 if (action == 0 and answer_reward < float(1/args.topn) and question_reward == cq_reward) \
                            or (action == 1  and question_reward == cq_penalty) else 0)

                    if (base_action == 0 or
                        (base_action == 1
                         and question_reward == cq_penalty)) and not base_stop:
                        base_stop = True
                        test_base_scores.append(answer_reward if base_action ==
                                                0 else 0)
                        if base_action == 0 and answer_reward == 1.0:
                            pass
                            #test_base_correct.append(test_id)
                        test_base_worse.append(1 if (base_action == 0 and answer_reward < float(1/args.topn) and question_reward == cq_reward) \
                            or (base_action == 1  and question_reward == cq_penalty) else 0)

                    if (score_action == 0 or
                        (score_action == 1 and question_reward
                         == cq_penalty)) and not score_stop:
                        score_stop = True
                        test_score_scores.append(
                            answer_reward if score_action == 0 else 0)
                        if score_action == 0 and answer_reward == 1.0:
                            pass
                            #test_score_correct.append(test_id)
                        test_score_worse.append(1 if (score_action == 0 and answer_reward < float(1/args.topn) and question_reward == cq_reward) \
                            or (score_action == 1  and question_reward == cq_penalty) else 0)

                    if (text_action == 0 or
                        (text_action == 1
                         and question_reward == cq_penalty)) and not text_stop:
                        text_stop = True
                        test_text_scores.append(answer_reward if text_action ==
                                                0 else 0)
                        if text_action == 0 and answer_reward == 1.0:
                            pass
                            #test_text_correct.append(test_id)
                        test_text_worse.append(1 if (text_action == 0 and answer_reward < float(1/args.topn) and question_reward == cq_reward) \
                            or (text_action == 1  and question_reward == cq_penalty) else 0)

                    if n_round == 0:
                        test_q0_scores.append(answer_reward)
                        test_q0_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass
                            #test_q0_correct.append(test_id)
                        if q_done:
                            test_q1_scores.append(0)
                            test_q2_scores.append(0)
                            test_q1_worse.append(1)
                            test_q2_worse.append(1)
                    elif n_round == 1:
                        test_q1_scores.append(answer_reward)
                        test_q1_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass
                            #test_q1_correct.append(test_id)
                        if q_done:
                            test_q2_scores.append(0)
                            test_q2_worse.append(1)
                    elif n_round == 2:
                        test_q2_scores.append(answer_reward)
                        test_q2_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass
                            #test_q2_correct.append(test_id)

                    n_round += 1
                    context = context_

            # save batch cache
            if args.cv != -1:
                T.save(
                    memory,
                    args.dataset_name + '_experiments/embedding_cache/' +
                    args.reranker_name + '/' + str(args.cv) +
                    '/test/memory.batchsave' + str(batch_serial))
            else:
                T.save(
                    memory, args.dataset_name +
                    '_experiments/embedding_cache/' + args.reranker_name +
                    '/test/memory.batchsave' + str(batch_serial))

            del memory
            T.cuda.empty_cache()

        for oi in range(len(test_scores)):
            test_oracle_scores.append(
                max(test_q0_scores[oi], test_q1_scores[oi],
                    test_q2_scores[oi]))
            test_oracle_worse.append(
                min(test_q0_worse[oi], test_q1_worse[oi], test_q2_worse[oi]))
        #test_oracle_correct = list(set(test_correct + test_q0_correct + test_q2_correct))

        print("Test epoch %.0f, acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (i, np.mean([1 if score == 1 else 0 for score in test_scores
                           ]), np.mean(test_scores), np.mean(test_worse)))
        print("q0 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_q0_scores
                        ]), np.mean(test_q0_scores), np.mean(test_q0_worse)))
        print("q1 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_q1_scores
                        ]), np.mean(test_q1_scores), np.mean(test_q1_worse)))
        print("q2 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_q2_scores
                        ]), np.mean(test_q2_scores), np.mean(test_q2_worse)))
        print(
            "oracle acc %.6f, avgmrr %.6f, worse decisions %.6f" %
            (np.mean([1 if score == 1 else 0 for score in test_oracle_scores]),
             np.mean(test_oracle_scores), np.mean(test_oracle_worse)))
        print("base acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_base_scores]),
               np.mean(test_base_scores), np.mean(test_base_worse)))
        print("score acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0
                        for score in test_score_scores]),
               np.mean(test_score_scores), np.mean(test_score_worse)))
        print("text acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_text_scores]),
               np.mean(test_text_scores), np.mean(test_text_worse)))
        '''