def _external_eval(model, global_step, sess, hparams, iterator,
                   iterator_feed_dict, tgt_file, label, summary_writer,
                   save_on_best, avg_ckpts=False):
  """External evaluation such as BLEU and ROUGE scores."""
  out_dir = hparams.out_dir

  if avg_ckpts:
    label = "avg_" + label

  utils.print_out("# External evaluation, global step %d" % global_step)
  sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

  output = os.path.join(out_dir, "output_%s" % label)
  scores = nmt_utils.decode_and_evaluate(
      label,
      model,
      sess,
      output,
      ref_file=tgt_file,
      metrics=hparams.metrics,
      subword_option=hparams.subword_option,
      beam_width=hparams.beam_width,
      tgt_eos=hparams.eos,
      hparams=hparams,
      decode=True)

  # Save on best metrics
  if global_step > 0:
    for metric in hparams.metrics:
      if avg_ckpts:
        best_metric_label = "avg_best_" + metric
      else:
        best_metric_label = "best_" + metric

      utils.add_summary(summary_writer, global_step, "%s_%s" % (label, metric),
                        scores[metric])
      # metric: larger is better
      if save_on_best and scores[metric] > getattr(hparams, best_metric_label):
        setattr(hparams, best_metric_label, scores[metric])
        model.saver.save(
            sess,
            os.path.join(
                getattr(hparams, "best_" + metric + "_dir"), "translate.ckpt"),
            global_step=model.global_step)
    utils.save_hparams(out_dir, hparams)
  return scores


def _internal_eval(hparams, model, global_step, sess, iterator,
                   iterator_feed_dict, summary_writer, label):
  """Computing perplexity."""
  utils.print_out("# Internal evaluation (perplexity), global step %d" %
                  global_step)
  sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
  ppl = model_helper.compute_perplexity(hparams, model, sess, label)
  utils.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl)
  return ppl


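# Illustrative sketch only: a minimal wrapper showing how `_internal_eval` is
# typically driven for the dev set, mirroring the upstream tensorflow/nmt
# pattern. The wrapper name is hypothetical, and the eval-model placeholder
# attributes (`src_file_placeholder`, `tgt_file_placeholder`) are assumptions
# borrowed from that codebase rather than guarantees about this module.
def _example_run_dev_internal_eval(eval_model, sess, model_dir, hparams,
                                   summary_writer):
  """Computes dev perplexity via _internal_eval (sketch, unused here)."""
  with eval_model.graph.as_default():
    loaded_eval_model, global_step = model_helper.create_or_load_model(
        eval_model.model, model_dir, sess, "eval")
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  dev_feed_dict = {
      eval_model.src_file_placeholder: dev_src_file,
      eval_model.tgt_file_placeholder: dev_tgt_file,
  }
  return _internal_eval(hparams, loaded_eval_model, global_step, sess,
                        eval_model.iterator, dev_feed_dict, summary_writer,
                        "dev")

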
def main(argv):
  del argv  # Unused.

  if FLAGS.debug:
    random.seed(0)

  reformulator_instance = reformulator.Reformulator(
      hparams_path=FLAGS.hparams_path,
      source_prefix=FLAGS.source_prefix,
      out_dir=FLAGS.out_dir,
      environment_server_address=FLAGS.environment_server_address)
  environment_fn = environment_client.make_environment_reward_fn(
      FLAGS.environment_server_address,
      mode=FLAGS.mode,
      env_call_parallelism=FLAGS.env_sample_parallelism)
  eval_environment_fn = environment_client.make_environment_reward_fn(
      FLAGS.environment_server_address,
      mode='searchqa',
      env_call_parallelism=FLAGS.env_eval_parallelism)

  # Read data.
  questions, annotations, docid_2_answer = read_data(
      questions_file=FLAGS.train_questions,
      annotations_file=FLAGS.train_annotations,
      answers_file=FLAGS.train_data,
      preprocessing_mode=FLAGS.mode)
  dev_questions, dev_annotations, dev_docid_2_answer = read_data(
      questions_file=FLAGS.dev_questions,
      annotations_file=FLAGS.dev_annotations,
      answers_file=FLAGS.dev_data,
      preprocessing_mode=FLAGS.mode,
      max_lines=FLAGS.max_dev_examples)

  # Summary writer that writes events to a folder. TensorBoard will later read
  # from it.
  summary_writer = tf.summary.FileWriter(
      os.path.join(
          FLAGS.tensorboard_dir,
          'reformulator_and_selector_training_log_' + str(time.time())))

  if FLAGS.enable_selector_training:
    selector_model = selector.Selector()

  last_save_step = 0
  global_step = 0
  for epoch in range(FLAGS.epochs):
    for batch_id, (questions_batch, annotations_batch) in enumerate(
        batch(questions, annotations, FLAGS.batch_size_train)):
      # Run eval every num_steps_per_eval batches.
      if global_step % FLAGS.num_steps_per_eval == 0:
        if FLAGS.debug:
          print('Running eval...')
        eval_start_time = time.time()
        if not FLAGS.enable_selector_training:
          eval_f1_avg = _run_reformulator_eval(dev_questions, dev_annotations,
                                               reformulator_instance,
                                               environment_fn,
                                               FLAGS.batch_size_eval)
        else:
          eval_f1_avg = _run_eval_with_selector(
              questions=dev_questions,
              annotations=dev_annotations,
              docid_2_answer=dev_docid_2_answer,
              reformulator_instance=reformulator_instance,
              selector_model=selector_model,
              batch_size=FLAGS.batch_size_eval,
              environment_fn=eval_environment_fn)

        # Correct the average F1 score for deleted datapoints in the SearchQA
        # dataset.
        if FLAGS.mode == 'searchqa':
          eval_f1_avg = _correct_searchqa_score(eval_f1_avg, dataset='dev')

        eval_time = time.time() - eval_start_time

        misc_utils.add_summary(
            summary_writer, global_step, tag='eval_f1_avg', value=eval_f1_avg)
        misc_utils.add_summary(
            summary_writer, global_step, tag='eval_time', value=eval_time)

        if FLAGS.debug:
          print('Avg F1 on dev: {}.'.format(eval_f1_avg))
          print('Time to finish eval: {}'.format(eval_time))

      start_time = time.time()
      if FLAGS.debug:
        print('Epoch {}, Batch {}.'.format(epoch, batch_id))
        print('Question: [{}]; Id: {}'.format(questions_batch[0],
                                              annotations_batch[0]))

      # Retrieve rewrites for selector training using beam search.
      if FLAGS.enable_selector_training:
        responses_beam = reformulator_instance.reformulate(
            questions=questions_batch,
            inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)
        # Discard answers.
        reformulations_beam = [[rf.reformulation for rf in rsp]
                               for rsp in responses_beam]

      if FLAGS.enable_reformulator_training:
        # Train reformulator model.
        if FLAGS.debug:
          print('Training reformulator...')
        reformulator_loss, f1s, reformulations = reformulator_instance.train(
            sources=questions_batch, annotations=annotations_batch)
        f1_avg = f1s.mean()

        if [] in reformulations:
          if FLAGS.debug:
            print('Found empty rewrites! Skipping this batch.')
          continue

        if FLAGS.debug:
          print('Rewrite: {}'.format(safe_string(reformulations[0])))
          print('Avg F1: {}'.format(f1_avg))
          print('Loss : {}'.format(reformulator_loss))

        # Write the f1_avg and loss to Tensorboard.
        misc_utils.add_summary(
            summary_writer, global_step, tag='f1_avg', value=f1_avg)
        misc_utils.add_summary(
            summary_writer,
            global_step,
            tag='reformulator_loss',
            value=reformulator_loss)

      # Train selector model.
      if FLAGS.enable_selector_training:
        (selector_questions, selector_answers,
         selector_scores) = query_environment(
             original_questions=questions_batch,
             rewrites=reformulations_beam,
             annotations=annotations_batch,
             environment_fn=eval_environment_fn,
             docid_2_answer=docid_2_answer,
             token_level_f1_scores=False)

        if FLAGS.debug:
          print('Training selector...')
        train_selector_loss, train_selector_accuracy = selector_model.train(
            selector_questions, selector_answers, selector_scores)

        # Regularly save a checkpoint.
        if global_step - last_save_step >= FLAGS.steps_per_save_selector:
          selector_model.save(str(global_step))
          last_save_step = global_step
          print('Selector saved at step: {}'.format(global_step))

        if FLAGS.debug:
          print('Train Accuracy: {}'.format(train_selector_accuracy))
          print('Train Loss : {}'.format(train_selector_loss))

        # Write the accuracy and loss to Tensorboard.
        misc_utils.add_summary(
            summary_writer,
            global_step,
            tag='train_selector_accuracy',
            value=train_selector_accuracy)
        misc_utils.add_summary(
            summary_writer,
            global_step,
            tag='train_selector_loss',
            value=train_selector_loss)

      iteration_time = time.time() - start_time
      if FLAGS.debug:
        print('Iteration time: {}'.format(iteration_time))
      misc_utils.add_summary(
          summary_writer,
          global_step,
          tag='iteration_time',
          value=iteration_time)

      # Increment the global counter.
      global_step += 1


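# The training loop in `main` relies on a `batch` helper (defined elsewhere in
# this module) that pairs questions with their annotations and yields them in
# fixed-size chunks. The generator below is only an illustrative sketch of
# that contract, under the assumption that the two lists are aligned by index;
# it uses a hypothetical name so it does not shadow the real helper.
def _example_batch(questions, annotations, batch_size):
  """Yields aligned (questions, annotations) slices of at most batch_size."""
  assert len(questions) == len(annotations)
  for start in range(0, len(questions), batch_size):
    yield (questions[start:start + batch_size],
           annotations[start:start + batch_size])

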
def train(hparams, scope=None, target_session=""):
  """Train a translation model."""
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  steps_per_eval = 10 * steps_per_stats
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  if not hparams.attention:
    model_creator = nmt_model.Model
  else:  # Attention
    if (hparams.encoder_type == "gnmt" or
        hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
      model_creator = gnmt_model.GNMTModel
    elif hparams.attention_architecture == "standard":
      model_creator = attention_model.AttentionModel
    else:
      raise ValueError("Unknown attention architecture %s" %
                       hparams.attention_architecture)

  combined_graph = tf.Graph()
  train_model = model_helper.create_train_model(
      model_creator, hparams, scope, graph=combined_graph)
  eval_model = model_helper.create_eval_model(
      model_creator, hparams, scope, graph=combined_graph)
  infer_model = model_helper.create_infer_model(
      model_creator, hparams, scope, graph=combined_graph)

  # Preload data for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  dev_ctx_file = None
  if hparams.ctx is not None:
    dev_ctx_file = "%s.%s" % (hparams.dev_prefix, hparams.ctx)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)
  sample_ctx_data = None
  if dev_ctx_file is not None:
    sample_ctx_data = inference.load_data(dev_ctx_file)
  sample_annot_data = None
  if hparams.dev_annotations is not None:
    sample_annot_data = inference.load_data(hparams.dev_annotations)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  sess = tf.Session(
      target=target_session, config=config_proto, graph=combined_graph)

  with train_model.graph.as_default():
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  # First evaluation
  run_full_eval(infer_model, sess, eval_model, sess, hparams, summary_writer,
                sample_src_data, sample_ctx_data, sample_tgt_data,
                sample_annot_data)

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  stats, info, start_train_time = before_train(loaded_train_model, train_model,
                                               sess, global_step, hparams,
                                               log_f)
  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(sess)
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      run_sample_decode(infer_model, sess, hparams, summary_writer,
                        sample_src_data, sample_ctx_data, sample_tgt_data,
                        sample_annot_data)
      run_external_eval(infer_model, sess, hparams, summary_writer)
      sess.run(
          train_model.iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Process step_result, accumulate stats, and write summary
    global_step, info["learning_rate"], step_summary = update_stats(
        stats, start_time, step_result)
    summary_writer.add_summary(step_summary, global_step)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      is_overflow = process_stats(stats, info, global_step, steps_per_stats,
                                  log_f)
      print_step_info(" ", global_step, info, _get_best_results(hparams),
                      log_f)
      if is_overflow:
        break

      # Reset statistics
      stats = init_stats()

    if global_step - last_eval_step >= steps_per_eval:
      last_eval_step = global_step
      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl",
                        info["train_ppl"])

      # Save checkpoint
      loaded_train_model.saver.save(
          sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)

      # Evaluate on dev/test
      run_sample_decode(infer_model, sess, hparams, summary_writer,
                        sample_src_data, sample_ctx_data, sample_tgt_data,
                        sample_annot_data)
      dev_ppl, test_ppl = None, None
      # Only evaluate perplexity when training with supervised learning.
      if not hparams.use_rl:
        dev_ppl, test_ppl = run_internal_eval(eval_model, sess, hparams,
                                              summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step

      # Save checkpoint
      loaded_train_model.saver.save(
          sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)
      run_sample_decode(infer_model, sess, hparams, summary_writer,
                        sample_src_data, sample_ctx_data, sample_tgt_data,
                        sample_annot_data)
      run_external_eval(infer_model, sess, hparams, summary_writer)

  # Done training
  loaded_train_model.saver.save(
      sess,
      os.path.join(out_dir, "translate.ckpt"),
      global_step=global_step)

  (result_summary, _, final_eval_metrics) = run_full_eval(
      infer_model, sess, eval_model, sess, hparams, summary_writer,
      sample_src_data, sample_ctx_data, sample_tgt_data, sample_annot_data)
  print_step_info("# Final, ", global_step, info, result_summary, log_f)
  utils.print_time("# Done training!", start_train_time)

  summary_writer.close()
  return final_eval_metrics, global_step
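

# Usage sketch (assumption, not a definitive entry point): `train` is normally
# invoked by a driver script after hparams have been built or loaded, e.g.
#
#   final_eval_metrics, global_step = train(hparams)
#
# It expects hparams to carry at least the fields read above (out_dir,
# num_train_steps, steps_per_stats, steps_per_external_eval, attention,
# attention_architecture, ctx, dev_annotations, use_rl, num_intra_threads,
# num_inter_threads, and the per-metric "best_*" attributes used when saving
# the best checkpoints).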