def _run_eval_with_selector(questions, annotations, docid_2_answer,
                            reformulator_instance, selector_model, batch_size,
                            environment_fn):
  """Runs a joint eval with the reformulator and selector model."""
  f1s = []
  for batch_id, (questions_batch, annotations_batch) in enumerate(
      batch(questions, annotations, batch_size)):
    responses = reformulator_instance.reformulate(
        questions=questions_batch,
        inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)
    # Discard answers.
    reformulations = [[rf.reformulation for rf in rsp] for rsp in responses]
    question_and_rewrites, answers, scores = query_environment(
        original_questions=questions_batch,
        rewrites=reformulations,
        annotations=annotations_batch,
        environment_fn=environment_fn,
        docid_2_answer=docid_2_answer,
        token_level_f1_scores=True)
    f1s.append(selector_model.eval(question_and_rewrites, answers, scores))

    if FLAGS.debug and batch_id == 0:
      print('Running Eval...')
      print('Questions: {}, Annotation: {}'.format(
          safe_string(questions_batch[0]), safe_string(annotations_batch[0])))
      print('Rewrites: {}'.format(safe_string(reformulations[0])))
      # Wrap in list() so the (answer, score) pairs render in the debug print
      # (zip is lazy in Python 3).
      print('Answers and Scores: {}'.format(
          list(zip(safe_string(answers[0]), scores[0]))))

  return np.mean(f1s)
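# NOTE: `batch`, `safe_string`, `query_environment`, `read_data`,
# `_run_reformulator_eval` and `_correct_searchqa_score` are helpers defined
# elsewhere in this file. As a non-authoritative sketch of the assumed
# contract, `batch` yields aligned, fixed-size slices of two parallel lists,
# roughly:
#
#   def batch(questions, annotations, batch_size):
#     for start in range(0, len(questions), batch_size):
#       yield (questions[start:start + batch_size],
#              annotations[start:start + batch_size])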
def main(argv):
  del argv  # Unused.

  if FLAGS.debug:
    random.seed(0)

  reformulator_instance = reformulator.Reformulator(
      hparams_path=FLAGS.hparams_path,
      source_prefix=FLAGS.source_prefix,
      out_dir=FLAGS.out_dir,
      environment_server_address=FLAGS.environment_server_address)
  environment_fn = environment_client.make_environment_reward_fn(
      FLAGS.environment_server_address,
      mode=FLAGS.mode,
      env_call_parallelism=FLAGS.env_sample_parallelism)
  eval_environment_fn = environment_client.make_environment_reward_fn(
      FLAGS.environment_server_address,
      mode='searchqa',
      env_call_parallelism=FLAGS.env_eval_parallelism)

  # Read data.
  questions, annotations, docid_2_answer = read_data(
      questions_file=FLAGS.train_questions,
      annotations_file=FLAGS.train_annotations,
      answers_file=FLAGS.train_data,
      preprocessing_mode=FLAGS.mode)
  dev_questions, dev_annotations, dev_docid_2_answer = read_data(
      questions_file=FLAGS.dev_questions,
      annotations_file=FLAGS.dev_annotations,
      answers_file=FLAGS.dev_data,
      preprocessing_mode=FLAGS.mode,
      max_lines=FLAGS.max_dev_examples)

  # Summary writer that writes events to a folder. TensorBoard will later read
  # from it.
  summary_writer = tf.summary.FileWriter(
      os.path.join(
          FLAGS.tensorboard_dir,
          'reformulator_and_selector_training_log_' + str(time.time())))

  if FLAGS.enable_selector_training:
    selector_model = selector.Selector()

  last_save_step = 0
  global_step = 0
  for epoch in range(FLAGS.epochs):
    for batch_id, (questions_batch, annotations_batch) in enumerate(
        batch(questions, annotations, FLAGS.batch_size_train)):

      # Run eval every num_steps_per_eval batches.
      if global_step % FLAGS.num_steps_per_eval == 0:
        if FLAGS.debug:
          print('Running eval...')
        eval_start_time = time.time()
        if not FLAGS.enable_selector_training:
          eval_f1_avg = _run_reformulator_eval(dev_questions, dev_annotations,
                                               reformulator_instance,
                                               environment_fn,
                                               FLAGS.batch_size_eval)
        else:
          eval_f1_avg = _run_eval_with_selector(
              questions=dev_questions,
              annotations=dev_annotations,
              docid_2_answer=dev_docid_2_answer,
              reformulator_instance=reformulator_instance,
              selector_model=selector_model,
              batch_size=FLAGS.batch_size_eval,
              environment_fn=eval_environment_fn)

        # Correct the average F1 score for deleted datapoints in the SearchQA
        # dataset.
        if FLAGS.mode == 'searchqa':
          eval_f1_avg = _correct_searchqa_score(eval_f1_avg, dataset='dev')

        eval_time = time.time() - eval_start_time

        misc_utils.add_summary(
            summary_writer, global_step, tag='eval_f1_avg', value=eval_f1_avg)
        misc_utils.add_summary(
            summary_writer, global_step, tag='eval_time', value=eval_time)

        if FLAGS.debug:
          print('Avg F1 on dev: {}.'.format(eval_f1_avg))
          print('Time to finish eval: {}'.format(eval_time))

      start_time = time.time()

      if FLAGS.debug:
        print('Epoch {}, Batch {}.'.format(epoch, batch_id))
        print('Question: [{}]; Id: {}'.format(questions_batch[0],
                                              annotations_batch[0]))

      # Retrieve rewrites for selector training using beam search.
      if FLAGS.enable_selector_training:
        responses_beam = reformulator_instance.reformulate(
            questions=questions_batch,
            inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)
        # Discard answers.
        reformulations_beam = [[rf.reformulation for rf in rsp]
                               for rsp in responses_beam]

      if FLAGS.enable_reformulator_training:
        # Train reformulator model.
        if FLAGS.debug:
          print('Training reformulator...')

        reformulator_loss, f1s, reformulations = reformulator_instance.train(
            sources=questions_batch, annotations=annotations_batch)
        f1_avg = f1s.mean()

        # Skip batches that produced empty rewrites.
        if [] in reformulations:
          if FLAGS.debug:
            print('Found empty rewrites! Skipping this batch.')
          continue

        if FLAGS.debug:
          print('Rewrite: {}'.format(safe_string(reformulations[0])))
          print('Avg F1: {}'.format(f1_avg))
          print('Loss : {}'.format(reformulator_loss))

        # Write the f1_avg and loss to Tensorboard.
        misc_utils.add_summary(
            summary_writer, global_step, tag='f1_avg', value=f1_avg)
        misc_utils.add_summary(
            summary_writer,
            global_step,
            tag='reformulator_loss',
            value=reformulator_loss)

      # Train selector model.
      if FLAGS.enable_selector_training:
        (selector_questions, selector_answers,
         selector_scores) = query_environment(
             original_questions=questions_batch,
             rewrites=reformulations_beam,
             annotations=annotations_batch,
             environment_fn=eval_environment_fn,
             docid_2_answer=docid_2_answer,
             token_level_f1_scores=False)

        if FLAGS.debug:
          print('Training selector...')

        train_selector_loss, train_selector_accuracy = selector_model.train(
            selector_questions, selector_answers, selector_scores)

        # Regularly save a checkpoint.
        if global_step - last_save_step >= FLAGS.steps_per_save_selector:
          selector_model.save(str(global_step))
          last_save_step = global_step
          print('Selector saved at step: {}'.format(global_step))

        if FLAGS.debug:
          print('Train Accuracy: {}'.format(train_selector_accuracy))
          print('Train Loss : {}'.format(train_selector_loss))

        # Write the accuracy and loss to Tensorboard.
        misc_utils.add_summary(
            summary_writer,
            global_step,
            tag='train_selector_accuracy',
            value=train_selector_accuracy)
        misc_utils.add_summary(
            summary_writer,
            global_step,
            tag='train_selector_loss',
            value=train_selector_loss)

      iteration_time = time.time() - start_time
      if FLAGS.debug:
        print('Iteration time: {}'.format(iteration_time))

      misc_utils.add_summary(
          summary_writer,
          global_step,
          tag='iteration_time',
          value=iteration_time)

      # Increment the global step counter.
      global_step += 1
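# Script entry point. The exact launcher is not shown in this excerpt; a
# typical TF 1.x-style invocation (an assumption, given that `tf` and
# command-line FLAGS are already used above) would be:
if __name__ == '__main__':
  tf.app.run(main)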