def main(argv):
  del argv  # Unused.

  if FLAGS.debug:
    random.seed(0)

  reformulator_instance = reformulator.Reformulator(
      hparams_path=FLAGS.hparams_path,
      source_prefix='<en> <2en> ',  # Hardcoded here; originally FLAGS.source_prefix.
      out_dir=FLAGS.out_dir,
      environment_server_address=FLAGS.environment_server_address)

  environment_fn = environment_client.make_environment_reward_fn(
      FLAGS.environment_server_address,
      mode=FLAGS.mode,
      env_call_parallelism=FLAGS.env_sample_parallelism)
  eval_environment_fn = environment_client.make_environment_reward_fn(
      FLAGS.environment_server_address,
      mode='searchqa',
      env_call_parallelism=FLAGS.env_eval_parallelism)

  # Read data.
  questions, annotations, docid_2_answer = read_data(
      questions_file=FLAGS.train_questions,
      annotations_file=FLAGS.train_annotations,
      answers_file=FLAGS.train_data,
      preprocessing_mode=FLAGS.mode)
  dev_questions, dev_annotations, dev_docid_2_answer = read_data(
      questions_file=FLAGS.dev_questions,
      annotations_file=FLAGS.dev_annotations,
      answers_file=FLAGS.dev_data,
      preprocessing_mode=FLAGS.mode,
      max_lines=FLAGS.max_dev_examples)

  # Inference: reformulate a few custom questions, then reformulate the
  # reformulations themselves.
  all_reformulated_question = []
  custom_questions = [
      'What is the reimbursement policies and how to claim it?'
  ]
  # Other example questions: 'what are the ways to save tax?',
  # 'How many casual leaves one is entitled in a year?',
  # 'why katappa killed bahubali?'
  all_reformulated_question.append(custom_questions)

  # Available inference modes: GREEDY, SAMPLING, BEAM_SEARCH, TRIE_GREEDY,
  # TRIE_SAMPLE, TRIE_BEAM_SEARCH.
  responses = reformulator_instance.reformulate(
      questions=custom_questions,
      inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)

  # Discard answers.
  custom_reformulations = [[rf.reformulation for rf in rsp]
                           for rsp in responses]
  all_reformulated_question.append(custom_reformulations)

  # Second pass: reformulate each first-pass reformulation.
  print('----------------------------reformulations of reformulation'
        '--------------------------------')
  for j in range(len(custom_reformulations)):
    responses_of1st_infer = reformulator_instance.reformulate(
        questions=custom_reformulations[j],
        inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)
    custom_reformulations_of1st_infer = [[rf.reformulation for rf in rsp]
                                         for rsp in responses_of1st_infer]
    all_reformulated_question.append(custom_reformulations_of1st_infer)

  all_reformulated_question = flatten(all_reformulated_question)
  print('all_reformulated_question:', len(all_reformulated_question),
        len(set(all_reformulated_question)), set(all_reformulated_question))

  # De-duplicate and write one question per line.
  all_reformulated_question = set(all_reformulated_question)
  with open('all_reformulated_question.txt', 'w') as outF:
    for q in all_reformulated_question:
      outF.write(q)
      outF.write('\n')
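# Note: the set() de-duplication above does not preserve the order in which
# rewrites were generated. An order-preserving alternative (a sketch, not
# part of the original; relies on dict insertion order, Python 3.7+):
def dedupe_preserving_order(items):
  """Returns items without duplicates, keeping first-seen order."""
  return list(dict.fromkeys(items))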
def main(argv):
  del argv  # Unused.

  if FLAGS.debug:
    random.seed(0)

  reformulator_instance = reformulator.Reformulator(
      hparams_path=FLAGS.hparams_path,
      source_prefix=FLAGS.source_prefix,
      out_dir=FLAGS.out_dir,
      environment_server_address=FLAGS.environment_server_address)

  environment_fn = environment_client.make_environment_reward_fn(
      FLAGS.environment_server_address,
      mode=FLAGS.mode,
      env_call_parallelism=FLAGS.env_sample_parallelism)
  eval_environment_fn = environment_client.make_environment_reward_fn(
      FLAGS.environment_server_address,
      mode='searchqa',
      env_call_parallelism=FLAGS.env_eval_parallelism)

  # Read data.
  questions, annotations, docid_2_answer = read_data(
      questions_file=FLAGS.train_questions,
      annotations_file=FLAGS.train_annotations,
      answers_file=FLAGS.train_data,
      preprocessing_mode=FLAGS.mode)
  dev_questions, dev_annotations, dev_docid_2_answer = read_data(
      questions_file=FLAGS.dev_questions,
      annotations_file=FLAGS.dev_annotations,
      answers_file=FLAGS.dev_data,
      preprocessing_mode=FLAGS.mode,
      max_lines=FLAGS.max_dev_examples)

  # Summary writer that writes events to a folder. TensorBoard will later
  # read from it.
  summary_writer = tf.summary.FileWriter(
      os.path.join(
          FLAGS.tensorboard_dir,
          'reformulator_and_selector_training_log_' + str(time.time())))

  if FLAGS.enable_selector_training:
    selector_model = selector.Selector()

  last_save_step = 0
  global_step = 0
  for epoch in range(FLAGS.epochs):
    for batch_id, (questions_batch, annotations_batch) in enumerate(
        batch(questions, annotations, FLAGS.batch_size_train)):
      # Run eval every num_steps_per_eval batches.
      if global_step % FLAGS.num_steps_per_eval == 0:
        if FLAGS.debug:
          print('Running eval...')
        eval_start_time = time.time()
        if not FLAGS.enable_selector_training:
          eval_f1_avg = _run_reformulator_eval(dev_questions, dev_annotations,
                                               reformulator_instance,
                                               environment_fn,
                                               FLAGS.batch_size_eval)
        else:
          eval_f1_avg = _run_eval_with_selector(
              questions=dev_questions,
              annotations=dev_annotations,
              docid_2_answer=dev_docid_2_answer,
              reformulator_instance=reformulator_instance,
              selector_model=selector_model,
              batch_size=FLAGS.batch_size_eval,
              environment_fn=eval_environment_fn)

        # Correct the average F1 score for deleted datapoints in the
        # SearchQA dataset.
        if FLAGS.mode == 'searchqa':
          eval_f1_avg = _correct_searchqa_score(eval_f1_avg, dataset='dev')

        eval_time = time.time() - eval_start_time

        misc_utils.add_summary(
            summary_writer, global_step, tag='eval_f1_avg', value=eval_f1_avg)
        misc_utils.add_summary(
            summary_writer, global_step, tag='eval_time', value=eval_time)

        if FLAGS.debug:
          print('Avg F1 on dev: {}.'.format(eval_f1_avg))
          print('Time to finish eval: {}'.format(eval_time))

      start_time = time.time()
      if FLAGS.debug:
        print('Epoch {}, Batch {}.'.format(epoch, batch_id))
        print('Question: [{}]; Id: {}'.format(questions_batch[0],
                                              annotations_batch[0]))

      # Retrieve rewrites for selector training using beam search.
      if FLAGS.enable_selector_training:
        responses_beam = reformulator_instance.reformulate(
            questions=questions_batch,
            inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)
        # Discard answers.
        reformulations_beam = [[rf.reformulation for rf in rsp]
                               for rsp in responses_beam]

      if FLAGS.enable_reformulator_training:
        # Train reformulator model.
        if FLAGS.debug:
          print('Training reformulator...')
        reformulator_loss, f1s, reformulations = reformulator_instance.train(
            sources=questions_batch, annotations=annotations_batch)
        f1_avg = f1s.mean()

        if [] in reformulations:
          if FLAGS.debug:
            print('Found empty rewrites! Skipping this batch.')
          continue

        if FLAGS.debug:
          print('Rewrite: {}'.format(safe_string(reformulations[0])))
          print('Avg F1: {}'.format(f1_avg))
          print('Loss  : {}'.format(reformulator_loss))

        # Write the f1_avg and loss to Tensorboard.
        misc_utils.add_summary(
            summary_writer, global_step, tag='f1_avg', value=f1_avg)
        misc_utils.add_summary(
            summary_writer,
            global_step,
            tag='reformulator_loss',
            value=reformulator_loss)

      # Train selector model.
      if FLAGS.enable_selector_training:
        (selector_questions, selector_answers,
         selector_scores) = query_environment(
             original_questions=questions_batch,
             rewrites=reformulations_beam,
             annotations=annotations_batch,
             environment_fn=eval_environment_fn,
             docid_2_answer=docid_2_answer,
             token_level_f1_scores=False)

        if FLAGS.debug:
          print('Training selector...')
        train_selector_loss, train_selector_accuracy = selector_model.train(
            selector_questions, selector_answers, selector_scores)

        # Regularly save a checkpoint.
        if global_step - last_save_step >= FLAGS.steps_per_save_selector:
          selector_model.save(str(global_step))
          last_save_step = global_step
          print('Selector saved at step: {}'.format(global_step))

        if FLAGS.debug:
          print('Train Accuracy: {}'.format(train_selector_accuracy))
          print('Train Loss    : {}'.format(train_selector_loss))

        # Write the accuracy and loss to Tensorboard.
        misc_utils.add_summary(
            summary_writer,
            global_step,
            tag='train_selector_accuracy',
            value=train_selector_accuracy)
        misc_utils.add_summary(
            summary_writer,
            global_step,
            tag='train_selector_loss',
            value=train_selector_loss)

      iteration_time = time.time() - start_time
      if FLAGS.debug:
        print('Iteration time: {}'.format(iteration_time))

      misc_utils.add_summary(
          summary_writer,
          global_step,
          tag='iteration_time',
          value=iteration_time)

      # Increment the global counter.
      global_step += 1
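# Entry point; a minimal sketch assuming the usual TF1-style flag setup
# (the flag definitions themselves are not shown in this excerpt).
if __name__ == '__main__':
  tf.app.run(main)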
      domain_data['de'][phrase] = (
          transformedRow[:index] + transformedRow[index + 1:])
    else:
      domain_data['en'][phrase] = (
          transformedRow[:index] + transformedRow[index + 1:])
  return domain_data


## ----------------------------------------------------------------------------
reformulator_instance = reformulator.Reformulator(
    hparams_path='px/nmt/example_configs/reformulator.json',
    source_prefix='<en> <2en> ',
    out_dir='./tmp/active-qa/reformulator/',
    environment_server_address='localhost:10000')


def flatten(lst):
  """Flattens an arbitrarily nested list of lists into a single flat list."""
  return sum(([x] if not isinstance(x, list) else flatten(x) for x in lst),
             [])


def question_paraphrase(query):
  questions = query
  all_reformulated_question = []
  # Beam-search inference; arguments assumed, mirroring the earlier
  # reformulate() calls.
  responses = reformulator_instance.reformulate(
      questions=questions,
      inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)
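  # Sketch of the remainder (an assumption, mirroring the inference logic in
  # main() above): keep the original query plus its beam-search rewrites,
  # then flatten and de-duplicate.
  reformulations = [[rf.reformulation for rf in rsp] for rsp in responses]
  all_reformulated_question.append(questions)
  all_reformulated_question.append(reformulations)
  return list(set(flatten(all_reformulated_question)))


# Hypothetical usage, assuming a reformulator environment server is running
# at localhost:10000:
#   question_paraphrase(['why katappa killed bahubali?'])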