def _run_test_repeat(self, tmpdir: str, fake_input: FakeInput):
    outfile = os.path.join(tmpdir, 'log.jsonl')
    Interactive.main(model='repeat_query', outfile=outfile)
    log = conversations.Conversations(outfile)
    self.assertEqual(len(log), fake_input.max_episodes)
    for entry in log:
        self.assertEqual(len(entry), 2 * fake_input.max_turns)
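# A minimal sketch of reading back a log written the same way, outside of a unittest.
# Assumptions: the `conversations` module used above is parlai.utils.conversations,
# and 'log.jsonl' is a hypothetical path produced by an earlier interactive run.
from parlai.utils.conversations import Conversations

log = Conversations('log.jsonl')
for episode in log:       # one entry per logged episode
    for turn in episode:  # alternating human / model turns
        print(turn)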
def test_constrained_beam_search(self):
    # call it with particular args
    Interactive.main(
        model='transformer/generator',
        task='blended_skill_talk',
        include_personas=False,
        include_initial_utterances=False,
        # history_size=2,
        beam_size=4,
        beam_min_length=20,
        beam_context_block_ngram=0,
        beam_block_ngram=3,
        inference='constrainedbeam',
        # inference='beam',
        model_parallel=True,
        # the model_file is a filename path pointing to a particular model dump.
        # Model files that begin with "zoo:" are special files distributed by the
        # ParlAI team. They'll be automatically downloaded when you ask to use them.
        # model_file="zoo:blender/blender_3B/model"
        model_file="zoo:blender/blender_90M/model")
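# If the custom 'constrainedbeam' inference mode is not available in your ParlAI
# checkout, the stock beam search hinted at by the commented-out line above is a
# reasonable fallback. A sketch only: every value is copied from the call above
# rather than re-tuned, and the 90M checkpoint is kept so model parallelism is
# not needed.
Interactive.main(
    model='transformer/generator',
    task='blended_skill_talk',
    include_personas=False,
    include_initial_utterances=False,
    beam_size=4,
    beam_min_length=20,
    beam_context_block_ngram=0,
    beam_block_ngram=3,
    inference='beam',
    model_file="zoo:blender/blender_90M/model")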
with open(
        '/raid/zhenduow/conversationalQA/data/' + exp_name + '/retrieved_question',
        'w') as qrerankerinput:
    for rq in retrieved_question:
        qrerankerinput.write(rq)
        qrerankerinput.write('\n')
with open(
        '/raid/zhenduow/conversationalQA/data/' + exp_name + '/retrieved_answer',
        'w') as arerankerinput:
    for ra in retrieved_answer:
        arerankerinput.write(ra)
        arerankerinput.write('\n')

# get reranker results
question = Interactive.main(
    model='transformer/polyencoder',
    model_file='zoo:pretrained_transformers/model_poly/msdialogquestion',
    encode_candidate_vecs='true',
    eval_candidates='fixed',
    fixed_candidates_path='/raid/zhenduow/conversationalQA/data/' + exp_name + '/retrieved_question',
    force_fp16_tokens=True,
    human_input=obs,
    fixed_candidate_vecs='replace')
answer = Interactive.main(
    model='transformer/polyencoder',
    model_file='zoo:pretrained_transformers/model_poly/msdialoganswer',
    encode_candidate_vecs='true',
    eval_candidates='fixed',
    fixed_candidates_path='/raid/zhenduow/conversationalQA/data/' + exp_name + '/retrieved_answer',
    human_input=obs,
    fixed_candidate_vecs='replace')
memory[obs] = [question, answer]
action = agent.choose_action(obs, question, answer)
obs_, question_reward, q_done, good_question = user.update_state(
    conv_id, obs, 1, question, answer, use_top_k=use_top_k)
_, answer_reward, _, _ = user.update_state(
    conv_id, obs, 0, question, answer, use_top_k=use_top_k)
def test_repeat(self):
    Interactive.main(model='repeat_query', task='convai2', datatype='valid')
def test_repeat(self):
    Interactive.main(model='repeat_query')
def main():
    model_path = '/home/warvisionary/parlai_transfer/from_pretrained/model'
    print(f"Interacting with model at: {model_path}")
    Interactive.main(task='empathetic_dialogues_ru', model_file=model_path)
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Basic example which allows local human keyboard input to talk to a trained model.

For documentation, see parlai.scripts.interactive.
"""
import random

from parlai.scripts.interactive import Interactive

if __name__ == '__main__':
    random.seed(42)
    Interactive.main()
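# When called with no arguments, Interactive.main() parses its options from the
# command line, so the script above is typically launched as, e.g.
#   python interactive.py --model-file zoo:blender/blender_90M/model
# The same options can also be passed programmatically as keyword arguments.
# A sketch under that assumption; the zoo checkpoint is just one of the models
# used elsewhere in these examples:
if __name__ == '__main__':
    random.seed(42)
    Interactive.main(model_file='zoo:blender/blender_90M/model')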
# Import the Interactive script
from parlai.scripts.interactive import Interactive

# call it with particular args
Interactive.main(
    # the model_file is a filename path pointing to a particular model dump.
    # Model files that begin with "zoo:" are special files distributed by the
    # ParlAI team. They'll be automatically downloaded when you ask to use them.
    model_file='zoo:tutorial_transformer_generator/model')
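# The same entry point forwards extra options on to the script and agent, so
# decoding settings can be supplied next to the model file. A sketch only: the
# option names below are the ones used in the beam-search example elsewhere in
# this file, and the values are illustrative rather than recommended settings
# for this particular model.
Interactive.main(
    model_file='zoo:tutorial_transformer_generator/model',
    inference='beam',
    beam_size=4,
    beam_min_length=20,
)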
def main(args):
    logging.getLogger().setLevel(logging.INFO)
    limit_memory(1e11)
    random.seed(2020)
    if args.cv != -1:
        train_dataset = ConversationDataset(
            'data/' + args.dataset_name + '-Complete/train' + str(args.cv) + '/',
            batch_size, max_train_size)
        test_dataset = ConversationDataset(
            'data/' + args.dataset_name + '-Complete/test' + str(args.cv) + '/',
            batch_size, max_test_size)
    else:
        train_dataset = ConversationDataset(
            'data/' + args.dataset_name + '-Complete/train/', batch_size,
            max_train_size)
        test_dataset = ConversationDataset(
            'data/' + args.dataset_name + '-Complete/test/', batch_size,
            max_test_size)
    train_size = sum(
        [len(b['conversations'].keys()) for b in train_dataset.batches])
    test_size = sum(
        [len(b['conversations'].keys()) for b in test_dataset.batches])
    agent = Agent(lr=1e-4,
                  input_dims=(3 + args.topn) * observation_dim + 1 + args.topn,
                  top_k=args.topn,
                  n_actions=action_num,
                  gamma=agent_gamma,
                  weight_decay=0.01)
    score_agent = ScoreAgent(lr=1e-4,
                             input_dims=1 + args.topn,
                             top_k=args.topn,
                             n_actions=action_num,
                             gamma=agent_gamma,
                             weight_decay=0.0)
    text_agent = TextAgent(lr=1e-4,
                           input_dims=(3 + args.topn) * observation_dim,
                           top_k=args.topn,
                           n_actions=action_num,
                           gamma=agent_gamma,
                           weight_decay=0.01)
    base_agent = BaseAgent(lr=1e-4,
                           input_dims=2 * observation_dim,
                           n_actions=2,
                           weight_decay=0.01)
    if args.dataset_name == 'MSDialog':
        reranker_prefix = ''
    elif args.dataset_name == 'UDC':
        reranker_prefix = 'udc'
    elif args.dataset_name == 'Opendialkg':
        reranker_prefix = 'open'

    # create rerankers
    if args.reranker_name == 'Poly':
        question_reranker = Interactive.main(
            model='transformer/polyencoder',
            model_file='zoo:pretrained_transformers/model_poly/' + reranker_prefix + 'question',
            encode_candidate_vecs=False,
            eval_candidates='inline',
            interactive_candidates='inline',
            return_cand_scores=True)
        answer_reranker = Interactive.main(
            model='transformer/polyencoder',
            model_file='zoo:pretrained_transformers/model_poly/' + reranker_prefix + 'answer',
            encode_candidate_vecs=False,
            eval_candidates='inline',
            interactive_candidates='inline',
            return_cand_scores=True)
        print("Loading rerankers:", 'model_poly/' + reranker_prefix + 'answer',
              'model_poly/' + reranker_prefix + 'question')
    elif args.reranker_name == 'Bi':
        bi_question_reranker = Interactive.main(
            model='transformer/biencoder',
            model_file='zoo:pretrained_transformers/model_bi/' + reranker_prefix + 'question',
            encode_candidate_vecs=False,
            eval_candidates='inline',
            interactive_candidates='inline',
            return_cand_scores=True)
        bi_answer_reranker = Interactive.main(
            model='transformer/biencoder',
            model_file='zoo:pretrained_transformers/model_bi/' + reranker_prefix + 'answer',
            encode_candidate_vecs=False,
            eval_candidates='inline',
            interactive_candidates='inline',
            return_cand_scores=True)
        print("Loading rerankers:", 'model_bi/' + reranker_prefix + 'answer',
              'model_bi/' + reranker_prefix + 'question')

    # embedding model
    tokenizer = AutoTokenizer.from_pretrained('xlnet-base-cased')
    embedding_model = AutoModel.from_pretrained('xlnet-base-cased')
    '''
    local_vars = list(locals().items())
    for var, obj in local_vars:
        print(var, get_size(obj))
    '''
    if not os.path.exists(args.dataset_name + '_experiments/embedding_cache/'):
        os.makedirs(args.dataset_name + '_experiments/embedding_cache/')
    if not os.path.exists(args.dataset_name + '_experiments/embedding_cache/' +
                          args.reranker_name):
        os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                    args.reranker_name)
    if args.cv != -1:
        if not os.path.exists(args.dataset_name + '_experiments/embedding_cache/' +
                              args.reranker_name + '/' + str(args.cv)):
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/' + str(args.cv))
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/' + str(args.cv) + '/train')
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/' + str(args.cv) + '/test')
    else:
        if not os.path.exists(args.dataset_name + '_experiments/embedding_cache/' +
                              args.reranker_name + '/train'):
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/train')
        if not os.path.exists(args.dataset_name + '_experiments/embedding_cache/' +
                              args.reranker_name + '/test'):
            os.makedirs(args.dataset_name + '_experiments/embedding_cache/' +
                        args.reranker_name + '/test')

    for i in range(train_iter):
        train_scores, train_q0_scores, train_q1_scores, train_q2_scores, train_oracle_scores, train_base_scores, train_score_scores, train_text_scores = [], [], [], [], [], [], [], []
        train_worse, train_q0_worse, train_q1_worse, train_q2_worse, train_oracle_worse, train_base_worse, train_score_worse, train_text_worse = [], [], [], [], [], [], [], []
        # train_correct, train_q0_correct, train_q1_correct, train_q2_correct, train_oracle_correct, train_base_correct, train_score_correct, train_text_correct = [], [], [], [], [], [], [], []
        for batch_serial, batch in enumerate(train_dataset.batches):
            print(dict(psutil.virtual_memory()._asdict()))
            if args.cv != -1:
                if os.path.exists(args.dataset_name +
                                  '_experiments/embedding_cache/' +
                                  args.reranker_name + '/' + str(args.cv) +
                                  '/train/memory.batchsave' + str(batch_serial)):
                    with T.no_grad():
                        memory = T.load(args.dataset_name +
                                        '_experiments/embedding_cache/' +
                                        args.reranker_name + '/' + str(args.cv) +
                                        '/train/memory.batchsave' +
                                        str(batch_serial))
                else:
                    memory = {}
            else:
                if os.path.exists(args.dataset_name +
                                  '_experiments/embedding_cache/' +
                                  args.reranker_name + '/train/memory.batchsave' +
                                  str(batch_serial)):
                    with T.no_grad():
                        memory = T.load(args.dataset_name +
                                        '_experiments/embedding_cache/' +
                                        args.reranker_name +
                                        '/train/memory.batchsave' +
                                        str(batch_serial))
                else:
                    memory = {}
            train_ids = list(batch['conversations'].keys())
            user = User(batch['conversations'],
                        cq_reward=cq_reward,
                        cq_penalty=cq_penalty)
            for conv_serial, train_id in enumerate(train_ids):
                query = user.initialize_state(train_id)
                if query == '':  # UDC dataset has some weird stuff
                    continue
                context = ''
                ignore_questions = []
                n_round = 0
                patience_used = 0
                q_done = False
                stop, base_stop, score_stop, text_stop = False, False, False, False
                print('-------- train batch %.0f conversation %.0f/%.0f --------' %
                      (batch_serial,
                       batch_size * (batch_serial) + conv_serial + 1, train_size))
                while not q_done:
                    total_tic = time.perf_counter()
                    print('-------- round %.0f --------' % (n_round))
                    if query in memory.keys():
                        if context not in memory[query].keys():
                            # sampling
                            question_candidates = generate_batch_question_candidates(
                                batch, train_id, ignore_questions, batch_size)
                            answer_candidates = generate_batch_answer_candidates(
                                batch, train_id, batch_size)
                            # get reranker results
                            if args.reranker_name == 'Poly':
                                questions, questions_scores = rerank(
                                    question_reranker, query, context,
                                    question_candidates)
                                answers, answers_scores = rerank(
                                    answer_reranker, query, context,
                                    answer_candidates)
                            elif args.reranker_name == 'Bi':
                                questions, questions_scores = rerank(
                                    bi_question_reranker, query, context,
                                    question_candidates)
                                answers, answers_scores = rerank(
                                    bi_answer_reranker, query, context,
                                    answer_candidates)
                            memory = save_to_memory(query, context, memory,
                                                    questions, answers,
                                                    questions_scores,
                                                    answers_scores, tokenizer,
                                                    embedding_model)
                    else:
                        # sampling
                        question_candidates = generate_batch_question_candidates(
                            batch, train_id, ignore_questions, batch_size)
                        answer_candidates = generate_batch_answer_candidates(
                            batch, train_id, batch_size)
                        # get reranker results
                        if args.reranker_name == 'Poly':
                            questions, questions_scores = rerank(
                                question_reranker, query, context,
                                question_candidates)
                            answers, answers_scores = rerank(
                                answer_reranker, query, context,
                                answer_candidates)
                        elif args.reranker_name == 'Bi':
                            questions, questions_scores = rerank(
                                bi_question_reranker, query, context,
                                question_candidates)
                            answers, answers_scores = rerank(
                                bi_answer_reranker, query, context,
                                answer_candidates)
                        memory = save_to_memory(query, context, memory, questions,
                                                answers, questions_scores,
                                                answers_scores, tokenizer,
                                                embedding_model)

                    query_embedding, context_embedding, questions, answers, questions_embeddings, answers_embeddings, questions_scores, answers_scores = read_from_memory(
                        query, context, memory)
                    action = agent.choose_action(query_embedding,
                                                 context_embedding,
                                                 questions_embeddings,
                                                 answers_embeddings,
                                                 questions_scores, answers_scores)
                    base_action = base_agent.choose_action(
                        query_embedding, context_embedding)
                    score_action = score_agent.choose_action(
                        questions_scores, answers_scores)
                    text_action = text_agent.choose_action(
                        query_embedding, context_embedding,
                        questions_embeddings, answers_embeddings)
                    evaluation_tic = time.perf_counter()
                    # context_, question_reward, q_done, good_question, patience_this_turn = user.update_state(train_id, context, 1, questions, answers, use_top_k = args.topn - patience_used)
                    context_, question_reward, q_done, good_question, patience_this_turn = user.update_state(
                        train_id, context, 1, questions, answers,
                        use_top_k=args.topn)
                    patience_used = max(patience_used + patience_this_turn,
                                        args.topn)
                    _, answer_reward, _, _, _ = user.update_state(
                        train_id, context, 0, questions, answers,
                        use_top_k=args.topn - patience_used)
                    action_reward = [answer_reward, question_reward][action]
                    evaluation_toc = time.perf_counter()
                    print('action', action, 'base_action', base_action,
                          'score_action', score_action, 'text_action',
                          text_action, 'answer reward', answer_reward,
                          'question reward', question_reward, 'q done', q_done)
                    if n_round >= max_round:
                        q_done = True
                    if not q_done:
                        ignore_questions.append(good_question)
                        if context_ not in memory[query].keys():
                            # sampling
                            question_candidates = generate_batch_question_candidates(
                                batch, train_id, ignore_questions, batch_size)
                            answer_candidates = generate_batch_answer_candidates(
                                batch, train_id, batch_size)
                            # get reranker results
                            if args.reranker_name == 'Poly':
                                questions_, questions_scores_ = rerank(
                                    question_reranker, query, context_,
                                    question_candidates)
                                answers_, answers_scores_ = rerank(
                                    answer_reranker, query, context_,
                                    answer_candidates)
                            elif args.reranker_name == 'Bi':
                                questions_, questions_scores_ = rerank(
                                    bi_question_reranker, query, context_,
                                    question_candidates)
                                answers_, answers_scores_ = rerank(
                                    bi_answer_reranker, query, context_,
                                    answer_candidates)
                            memory = save_to_memory(query, context_, memory,
                                                    questions_, answers_,
                                                    questions_scores_,
                                                    answers_scores_, tokenizer,
                                                    embedding_model)
                        query_embedding, context_embedding_, questions_, answers_, questions_embeddings_, answers_embeddings_, questions_scores_, answers_scores_ = read_from_memory(
                            query, context_, memory)
                    else:
                        context_embedding_ = generate_embedding_no_grad(
                            context_, tokenizer, embedding_model)
                        questions_, answers_, questions_embeddings_, answers_embeddings_, questions_scores_, answers_scores_ = None, None, None, None, None, None
                    agent.joint_learn(
                        (query_embedding, context_embedding,
                         questions_embeddings, answers_embeddings,
                         questions_scores, answers_scores),
                        answer_reward, question_reward,
                        (query_embedding, context_embedding_,
                         questions_embeddings_, answers_embeddings_,
                         questions_scores_, answers_scores_))
                    base_agent.learn(
                        query_embedding, context_embedding,
                        0 if (n_round + 1) == len(user.dataset[train_id]) / 2 else 1)
                    score_agent.joint_learn(
                        (questions_scores, answers_scores),
                        answer_reward, question_reward,
                        (questions_scores_, answers_scores_))
                    text_agent.joint_learn(
                        (query_embedding, context_embedding,
                         questions_embeddings, answers_embeddings),
                        answer_reward, question_reward,
                        (query_embedding, context_embedding_,
                         questions_embeddings_, answers_embeddings_))

                    # evaluation
                    if (action == 0 or
                            (action == 1 and question_reward == cq_penalty)) and not stop:
                        stop = True
                        train_scores.append(answer_reward if action == 0 else 0)
                        if action == 0 and answer_reward == 1.0:
                            pass  # train_correct.append(train_id)
                        train_worse.append(
                            1 if (action == 0 and answer_reward < float(1 / args.topn) and question_reward == cq_reward)
                            or (action == 1 and question_reward == cq_penalty) else 0)
                    if (base_action == 0 or
                            (base_action == 1 and question_reward == cq_penalty)) and not base_stop:
                        base_stop = True
                        train_base_scores.append(
                            answer_reward if base_action == 0 else 0)
                        if base_action == 0 and answer_reward == 1.0:
                            pass  # train_base_correct.append(train_id)
                        train_base_worse.append(
                            1 if (base_action == 0 and answer_reward < float(1 / args.topn) and question_reward == cq_reward)
                            or (base_action == 1 and question_reward == cq_penalty) else 0)
                    if (score_action == 0 or
                            (score_action == 1 and question_reward == cq_penalty)) and not score_stop:
                        score_stop = True
                        train_score_scores.append(
                            answer_reward if score_action == 0 else 0)
                        if score_action == 0 and answer_reward == 1.0:
                            pass  # train_score_correct.append(train_id)
                        train_score_worse.append(
                            1 if (score_action == 0 and answer_reward < float(1 / args.topn) and question_reward == cq_reward)
                            or (score_action == 1 and question_reward == cq_penalty) else 0)
                    if (text_action == 0 or
                            (text_action == 1 and question_reward == cq_penalty)) and not text_stop:
                        text_stop = True
                        train_text_scores.append(
                            answer_reward if text_action == 0 else 0)
                        if text_action == 0 and answer_reward == 1.0:
                            pass  # train_text_correct.append(train_id)
                        train_text_worse.append(
                            1 if (text_action == 0 and answer_reward < float(1 / args.topn) and question_reward == cq_reward)
                            or (text_action == 1 and question_reward == cq_penalty) else 0)

                    if n_round == 0:
                        train_q0_scores.append(answer_reward)
                        train_q0_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass  # train_q0_correct.append(train_id)
                        if q_done:
                            train_q1_scores.append(0)
                            train_q2_scores.append(0)
                            train_q1_worse.append(1)
                            train_q2_worse.append(1)
                    elif n_round == 1:
                        train_q1_scores.append(answer_reward)
                        train_q1_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass  # train_q1_correct.append(train_id)
                        if q_done:
                            train_q2_scores.append(0)
                            train_q2_worse.append(1)
                    elif n_round == 2:
                        train_q2_scores.append(answer_reward)
                        train_q2_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass  # train_q2_correct.append(train_id)

                    context = context_
                    n_round += 1
                    total_toc = time.perf_counter()

            # save memory per batch
            if args.cv != -1:
                T.save(
                    memory, args.dataset_name + '_experiments/embedding_cache/' +
                    args.reranker_name + '/' + str(args.cv) +
                    '/train/memory.batchsave' + str(batch_serial))
            else:
                T.save(
                    memory, args.dataset_name + '_experiments/embedding_cache/' +
                    args.reranker_name + '/train/memory.batchsave' +
                    str(batch_serial))
            del memory
            T.cuda.empty_cache()

        for oi in range(len(train_scores)):
            train_oracle_scores.append(
                max(train_q0_scores[oi], train_q1_scores[oi],
                    train_q2_scores[oi]))
            train_oracle_worse.append(
                min(train_q0_worse[oi], train_q1_worse[oi], train_q2_worse[oi]))
        # train_oracle_correct = list(set(train_correct + train_q0_correct + train_q2_correct))

        print("Train epoch %.0f, acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (i, np.mean([1 if score == 1 else 0 for score in train_scores]),
               np.mean(train_scores), np.mean(train_worse)))
        print("q0 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_q0_scores]),
               np.mean(train_q0_scores), np.mean(train_q0_worse)))
        print("q1 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_q1_scores]),
               np.mean(train_q1_scores), np.mean(train_q1_worse)))
        print("q2 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_q2_scores]),
               np.mean(train_q2_scores), np.mean(train_q2_worse)))
        print("oracle acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_oracle_scores]),
               np.mean(train_oracle_scores), np.mean(train_oracle_worse)))
        print("base acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_base_scores]),
               np.mean(train_base_scores), np.mean(train_base_worse)))
        print("score acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_score_scores]),
               np.mean(train_score_scores), np.mean(train_score_worse)))
        print("text acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in train_text_scores]),
               np.mean(train_text_scores), np.mean(train_text_worse)))
        '''
        print(train_correct)
        print(train_q0_correct)
        print(train_q1_correct)
        print(train_q2_correct)
        print(train_oracle_correct)
        print(train_base_correct)
        print(train_score_correct)
        print(train_text_correct)
        '''
        print("avg loss", np.mean(agent.loss_history))

        ## test
        test_scores, test_q0_scores, test_q1_scores, test_q2_scores, test_oracle_scores, test_base_scores, test_score_scores, test_text_scores = [], [], [], [], [], [], [], []
        test_worse, test_q0_worse, test_q1_worse, test_q2_worse, test_oracle_worse, test_base_worse, test_score_worse, test_text_worse = [], [], [], [], [], [], [], []
        # test_correct, test_q0_correct, test_q1_correct, test_q2_correct, test_oracle_correct, test_base_correct, test_score_correct, test_text_correct = [], [], [], [], [], [], [], []
        # test the agent
        agent.epsilon = 0
        for batch_serial, batch in enumerate(test_dataset.batches):
            if args.cv != -1:
                if os.path.exists(args.dataset_name +
                                  '_experiments/embedding_cache/' +
                                  args.reranker_name + '/' + str(args.cv) +
                                  '/test/memory.batchsave' + str(batch_serial)):
                    with T.no_grad():
                        memory = T.load(args.dataset_name +
                                        '_experiments/embedding_cache/' +
                                        args.reranker_name + '/' + str(args.cv) +
                                        '/test/memory.batchsave' +
                                        str(batch_serial))
                else:
                    memory = {}
            else:
                if os.path.exists(args.dataset_name +
                                  '_experiments/embedding_cache/' +
                                  args.reranker_name + '/test/memory.batchsave' +
                                  str(batch_serial)):
                    with T.no_grad():
                        memory = T.load(args.dataset_name +
                                        '_experiments/embedding_cache/' +
                                        args.reranker_name +
                                        '/test/memory.batchsave' +
                                        str(batch_serial))
                else:
                    memory = {}
            test_ids = list(batch['conversations'].keys())
            user = User(batch['conversations'],
                        cq_reward=cq_reward,
                        cq_penalty=cq_penalty)
            for conv_serial, test_id in enumerate(test_ids):
                query = user.initialize_state(test_id)
                if query == '':  # UDC dataset has some weird stuff
                    continue
                context = ''
                ignore_questions = []
                n_round = 0
                patience_used = 0
                q_done = False
                stop, base_stop, score_stop, text_stop = False, False, False, False
                print('-------- test batch %.0f conversation %.0f/%.0f --------' %
                      (batch_serial,
                       batch_size * (batch_serial) + conv_serial + 1, test_size))
                while not q_done:
                    print('-------- round %.0f --------' % (n_round))
                    if query in memory.keys():
                        if context not in memory[query].keys():
                            # sampling
                            question_candidates = generate_batch_question_candidates(
                                batch, test_id, ignore_questions, batch_size)
                            answer_candidates = generate_batch_answer_candidates(
                                batch, test_id, batch_size)
                            # get reranker results
                            if args.reranker_name == 'Poly':
                                questions, questions_scores = rerank(
                                    question_reranker, query, context,
                                    question_candidates)
                                answers, answers_scores = rerank(
                                    answer_reranker, query, context,
                                    answer_candidates)
                            elif args.reranker_name == 'Bi':
                                questions, questions_scores = rerank(
                                    bi_question_reranker, query, context,
                                    question_candidates)
                                answers, answers_scores = rerank(
                                    bi_answer_reranker, query, context,
                                    answer_candidates)
                            memory = save_to_memory(query, context, memory,
                                                    questions, answers,
                                                    questions_scores,
                                                    answers_scores, tokenizer,
                                                    embedding_model)
                    else:
                        # sampling
                        question_candidates = generate_batch_question_candidates(
                            batch, test_id, ignore_questions, batch_size)
                        answer_candidates = generate_batch_answer_candidates(
                            batch, test_id, batch_size)
                        # get reranker results
                        if args.reranker_name == 'Poly':
                            questions, questions_scores = rerank(
                                question_reranker, query, context,
                                question_candidates)
                            answers, answers_scores = rerank(
                                answer_reranker, query, context,
                                answer_candidates)
                        elif args.reranker_name == 'Bi':
                            questions, questions_scores = rerank(
                                bi_question_reranker, query, context,
                                question_candidates)
                            answers, answers_scores = rerank(
                                bi_answer_reranker, query, context,
                                answer_candidates)
                        memory = save_to_memory(query, context, memory, questions,
                                                answers, questions_scores,
                                                answers_scores, tokenizer,
                                                embedding_model)

                    query_embedding, context_embedding, questions, answers, questions_embeddings, answers_embeddings, questions_scores, answers_scores = read_from_memory(
                        query, context, memory)
                    action = agent.choose_action(query_embedding,
                                                 context_embedding,
                                                 questions_embeddings,
                                                 answers_embeddings,
                                                 questions_scores, answers_scores)
                    base_action = base_agent.choose_action(
                        query_embedding, context_embedding)
                    score_action = score_agent.choose_action(
                        questions_scores, answers_scores)
                    text_action = text_agent.choose_action(
                        query_embedding, context_embedding,
                        questions_embeddings, answers_embeddings)
                    # context_, question_reward, q_done, good_question, patience_this_turn = user.update_state(test_id, context, 1, questions, answers, use_top_k = args.topn - patience_used)
                    context_, question_reward, q_done, good_question, patience_this_turn = user.update_state(
                        test_id, context, 1, questions, answers,
                        use_top_k=args.topn)
                    patience_used = max(patience_used + patience_this_turn,
                                        args.topn)
                    _, answer_reward, _, _, _ = user.update_state(
                        test_id, context, 0, questions, answers,
                        use_top_k=args.topn - patience_used)
                    action_reward = [answer_reward, question_reward][action]
                    print('action', action, 'base_action', base_action,
                          'score_action', score_action, 'text_action',
                          text_action, 'answer reward', answer_reward,
                          'question reward', question_reward, 'q done', q_done)
                    if n_round >= max_round:
                        q_done = True
                    if not q_done:
                        ignore_questions.append(good_question)
                        if context_ not in memory[query].keys():
                            # sampling
                            question_candidates = generate_batch_question_candidates(
                                batch, test_id, ignore_questions, batch_size)
                            answer_candidates = generate_batch_answer_candidates(
                                batch, test_id, batch_size)
                            # get reranker results
                            if args.reranker_name == 'Poly':
                                questions_, questions_scores_ = rerank(
                                    question_reranker, query, context_,
                                    question_candidates)
                                answers_, answers_scores_ = rerank(
                                    answer_reranker, query, context_,
                                    answer_candidates)
                            elif args.reranker_name == 'Bi':
                                questions_, questions_scores_ = rerank(
                                    bi_question_reranker, query, context_,
                                    question_candidates)
                                answers_, answers_scores_ = rerank(
                                    bi_answer_reranker, query, context_,
                                    answer_candidates)
                            memory = save_to_memory(query, context_, memory,
                                                    questions_, answers_,
                                                    questions_scores_,
                                                    answers_scores_, tokenizer,
                                                    embedding_model)
                        query_embedding, context_embedding_, questions_, answers_, questions_embeddings_, answers_embeddings_, questions_scores_, answers_scores_ = read_from_memory(
                            query, context_, memory)

                    # evaluation
                    if (action == 0 or
                            (action == 1 and question_reward == cq_penalty)) and not stop:
                        stop = True
                        test_scores.append(answer_reward if action == 0 else 0)
                        if action == 0 and answer_reward == 1.0:
                            pass  # test_correct.append(test_id)
                        test_worse.append(
                            1 if (action == 0 and answer_reward < float(1 / args.topn) and question_reward == cq_reward)
                            or (action == 1 and question_reward == cq_penalty) else 0)
                    if (base_action == 0 or
                            (base_action == 1 and question_reward == cq_penalty)) and not base_stop:
                        base_stop = True
                        test_base_scores.append(
                            answer_reward if base_action == 0 else 0)
                        if base_action == 0 and answer_reward == 1.0:
                            pass  # test_base_correct.append(test_id)
                        test_base_worse.append(
                            1 if (base_action == 0 and answer_reward < float(1 / args.topn) and question_reward == cq_reward)
                            or (base_action == 1 and question_reward == cq_penalty) else 0)
                    if (score_action == 0 or
                            (score_action == 1 and question_reward == cq_penalty)) and not score_stop:
                        score_stop = True
                        test_score_scores.append(
                            answer_reward if score_action == 0 else 0)
                        if score_action == 0 and answer_reward == 1.0:
                            pass  # test_score_correct.append(test_id)
                        test_score_worse.append(
                            1 if (score_action == 0 and answer_reward < float(1 / args.topn) and question_reward == cq_reward)
                            or (score_action == 1 and question_reward == cq_penalty) else 0)
                    if (text_action == 0 or
                            (text_action == 1 and question_reward == cq_penalty)) and not text_stop:
                        text_stop = True
                        test_text_scores.append(
                            answer_reward if text_action == 0 else 0)
                        if text_action == 0 and answer_reward == 1.0:
                            pass  # test_text_correct.append(test_id)
                        test_text_worse.append(
                            1 if (text_action == 0 and answer_reward < float(1 / args.topn) and question_reward == cq_reward)
                            or (text_action == 1 and question_reward == cq_penalty) else 0)

                    if n_round == 0:
                        test_q0_scores.append(answer_reward)
                        test_q0_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass  # test_q0_correct.append(test_id)
                        if q_done:
                            test_q1_scores.append(0)
                            test_q2_scores.append(0)
                            test_q1_worse.append(1)
                            test_q2_worse.append(1)
                    elif n_round == 1:
                        test_q1_scores.append(answer_reward)
                        test_q1_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass  # test_q1_correct.append(test_id)
                        if q_done:
                            test_q2_scores.append(0)
                            test_q2_worse.append(1)
                    elif n_round == 2:
                        test_q2_scores.append(answer_reward)
                        test_q2_worse.append(
                            1 if answer_reward < float(1 / args.topn)
                            and question_reward == cq_reward else 0)
                        if answer_reward == 1:
                            pass  # test_q2_correct.append(test_id)

                    n_round += 1
                    context = context_

            # save batch cache
            if args.cv != -1:
                T.save(
                    memory, args.dataset_name + '_experiments/embedding_cache/' +
                    args.reranker_name + '/' + str(args.cv) +
                    '/test/memory.batchsave' + str(batch_serial))
            else:
                T.save(
                    memory, args.dataset_name + '_experiments/embedding_cache/' +
                    args.reranker_name + '/test/memory.batchsave' +
                    str(batch_serial))
            del memory
            T.cuda.empty_cache()

        for oi in range(len(test_scores)):
            test_oracle_scores.append(
                max(test_q0_scores[oi], test_q1_scores[oi], test_q2_scores[oi]))
            test_oracle_worse.append(
                min(test_q0_worse[oi], test_q1_worse[oi], test_q2_worse[oi]))
        # test_oracle_correct = list(set(test_correct + test_q0_correct + test_q2_correct))

        print("Test epoch %.0f, acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (i, np.mean([1 if score == 1 else 0 for score in test_scores]),
               np.mean(test_scores), np.mean(test_worse)))
        print("q0 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_q0_scores]),
               np.mean(test_q0_scores), np.mean(test_q0_worse)))
        print("q1 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_q1_scores]),
               np.mean(test_q1_scores), np.mean(test_q1_worse)))
        print("q2 acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_q2_scores]),
               np.mean(test_q2_scores), np.mean(test_q2_worse)))
        print("oracle acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_oracle_scores]),
               np.mean(test_oracle_scores), np.mean(test_oracle_worse)))
        print("base acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_base_scores]),
               np.mean(test_base_scores), np.mean(test_base_worse)))
        print("score acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_score_scores]),
               np.mean(test_score_scores), np.mean(test_score_worse)))
        print("text acc %.6f, avgmrr %.6f, worse decisions %.6f" %
              (np.mean([1 if score == 1 else 0 for score in test_text_scores]),
               np.mean(test_text_scores), np.mean(test_text_worse)))
        '''