def train(hparams, scope=None, target_session="", compute_ppl=0):
  """Train a translation model.

  Builds separate train/eval/infer graphs, restores or initializes model
  variables from `hparams.out_dir`, then runs the main training loop with
  periodic statistics, checkpointing, and internal/external evaluation.

  Args:
    hparams: hyperparameter object; read for paths, eval schedules, and the
      attention/encoder architecture selection.
    scope: optional variable scope passed to the model-building helpers.
    target_session: TF session target (for distributed execution).
    compute_ppl: when truthy, run an internal (perplexity) eval after every
      training step.
  """
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  steps_per_eval = 10 * steps_per_stats
  avg_ckpts = hparams.avg_ckpts

  # Default the external-eval cadence when not set explicitly.
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  # Pick the model class from the attention configuration.
  if not hparams.attention:
    # choose this model
    model_creator = nmt_model.Model
  else:  # Attention
    if (hparams.encoder_type == "gnmt" or
        hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
      model_creator = gnmt_model.GNMTModel
    elif hparams.attention_architecture == "standard":
      model_creator = attention_model.AttentionModel
    else:
      raise ValueError("Unknown attention architecture %s" %
                       hparams.attention_architecture)

  train_model = model_helper.create_train_model(model_creator, hparams, scope)
  eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
  infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

  # Preload data for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)
  # NOTE(review): wsd_src_data is loaded but never used in this function —
  # presumably intended for the WSD reward path referenced in the loop
  # comments below; confirm.
  wsd_src_file = "%s" % (hparams.sample_prefix)
  wsd_src_data = inference.load_data(wsd_src_file)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  # One session per graph (train / eval / infer).
  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)
  infer_sess = tf.Session(
      target=target_session, config=config_proto, graph=infer_model.graph)

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  # First evaluation (currently disabled).
  '''
  run_full_eval(
      model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
      summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)
  '''

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  stats, info, start_train_time = before_train(
      loaded_train_model, train_model, train_sess, global_step, hparams, log_f)
  # NOTE(review): the loop is capped at global_step + 100 rather than
  # num_train_steps (which is otherwise unused here) — looks like a debugging
  # limit; confirm before relying on this for full training runs.
  end_step = global_step + 100
  while global_step < end_step:  # num_train_steps
    ### Run a step ###
    start_time = time.time()
    try:
      # then forward inference result to WSD, get reward
      step_result = loaded_train_model.train(train_sess)
      # forward reward to placeholder of loaded_train_model, and write a new
      # train function where loss = loss*reward
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      # run_sample_decode(infer_model, infer_sess, model_dir, hparams,
      #                   summary_writer, sample_src_data, sample_tgt_data)
      # only for pretrain
      # run_external_eval(infer_model, infer_sess, model_dir, hparams,
      #                   summary_writer)
      if avg_ckpts:
        run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, global_step)
      # Restart the training iterator from the beginning of the dataset.
      train_sess.run(train_model.iterator.initializer,
                     feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Process step_result, accumulate stats, and write summary
    global_step, info["learning_rate"], step_summary = update_stats(
        stats, start_time, step_result, hparams)
    summary_writer.add_summary(step_summary, global_step)
    if compute_ppl:
      run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                        summary_writer)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      is_overflow = process_stats(stats, info, global_step, steps_per_stats,
                                  log_f)
      print_step_info("  ", global_step, info, _get_best_results(hparams),
                      log_f)
      if is_overflow:
        break
      # Reset statistics
      stats = init_stats()

    if global_step - last_eval_step >= steps_per_eval:
      last_eval_step = global_step
      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl",
                        info["train_ppl"])
      # Save checkpoint
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "translate.ckpt"),
                                    global_step=global_step)
      # Evaluate on dev/test
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                        summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step
      # Save checkpoint
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "translate.ckpt"),
                                    global_step=global_step)
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      run_external_eval(infer_model, infer_sess, model_dir, hparams,
                        summary_writer)
      if avg_ckpts:
        run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, global_step)

  # Done training
  loaded_train_model.saver.save(train_sess,
                                os.path.join(out_dir, "translate.ckpt"),
                                global_step=global_step)

# NOTE(review): unmatched triple-quote below — it appears to begin commenting
# out the next definition, but its closing counterpart is not visible in this
# chunk; confirm the file still parses.
'''
def train(hparams):
  """Train a seq2seq model.

  Restores (or freshly initializes) the model under `hparams.out_dir`, runs
  the training loop until `hparams.num_train_steps`, logging statistics every
  `hparams.steps_per_stats` steps, then saves a final checkpoint.

  Args:
    hparams: hyperparameter object (out_dir, num_train_steps,
      steps_per_stats, thread counts, epoch_step, ...).

  Returns:
    The global step at which training stopped.
  """
  out_dir = hparams.out_dir
  steps_per_stats = hparams.steps_per_stats

  # Build the training graph.
  train_model = model_helper.create_train_model(model.Model, hparams)

  # Per-run log file.
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_files=%s" % log_file, log_f)

  # Session for the training graph.
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  train_sess = tf.Session(config=config_proto, graph=train_model.graph)

  # Restore the latest checkpoint if one exists, otherwise initialize.
  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, hparams.out_dir, train_sess, "train")

  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)

  last_stats_step = global_step

  # Main training loop.
  stats, info, _start_train_time = before_train(
      loaded_train_model, train_model, train_sess, global_step, hparams, log_f)
  while global_step < hparams.num_train_steps:
    step_start = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Epoch boundary: log it, rewind the iterator, and keep training.
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Fold the step result into running statistics and emit a summary.
    global_step, info["learning_rate"], step_summary = update_stats(
        stats, step_start, step_result)
    summary_writer.add_summary(step_summary, global_step)

    if global_step - last_stats_step >= steps_per_stats:
      # Periodic statistics report; stop on numeric overflow.
      last_stats_step = global_step
      overflow = process_stats(stats, info, global_step, steps_per_stats,
                               log_f)
      print_step_info("  ", global_step, info, log_f)
      if overflow:
        break
      stats = init_stats()

  # Final checkpoint.
  loaded_train_model.saver.save(
      train_sess, os.path.join(out_dir, "seq2seq.ckpt"),
      global_step=global_step)
  summary_writer.close()
  return global_step
def train(hparams, identity, scope=None, target_session=""):
  """main loop to train the dialogue model. identity is used.

  Restores or initializes the model in `hparams.out_dir`, then trains until
  `hparams.num_train_steps`, periodically logging stats, checkpointing, and
  writing a train-perplexity summary.

  Args:
    hparams: hyperparameter object (out_dir, steps_per_stats, batch_size,
      epoch_step, num_train_steps, debug, ...). `hparams.epoch_step` is
      mutated in place to track progress within an epoch.
    identity: string prefix for this run's log-file name.
    scope: optional variable scope for model creation.
    target_session: TF session target (for distributed execution).
  """
  out_dir = hparams.out_dir
  steps_per_stats = hparams.steps_per_stats
  # Internal eval runs every 3 stats intervals.
  steps_per_internal_eval = 3 * steps_per_stats

  model_creator = diag_model.Model
  train_model = model_helper.create_train_model(model_creator, hparams, scope)

  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, identity + "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  avg_step_time = 0.0

  # load TensorFlow session and model
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)

  # Feedable-iterator handle for the training dataset.
  train_handle = train_sess.run(train_model.train_iterator.string_handle())

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # initialize summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)

  last_stats_step = global_step
  last_eval_step = global_step

  # initialize training stats.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()), log_f)

  # initialize iterators; skip data already consumed in this epoch so a
  # restarted run resumes where it left off.
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.train_iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  # main training loop
  while global_step < hparams.num_train_steps:
    start_time = time.time()
    try:
      # run a step
      step_result = loaded_train_model.train(train_sess, train_handle)
      (_, step_loss, all_summaries, step_predict_count, step_summary,
       global_step, step_word_count, batch_size, _, _, words1, words2,
       mask1, mask2) = step_result
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # finished an epoch: rewind the iterator and continue training.
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.train_iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    for key in all_summaries:
      utils.add_summary(summary_writer, global_step, key, all_summaries[key])

    # update statistics
    step_time += (time.time() - start_time)
    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    if global_step - last_stats_step >= steps_per_stats:
      # print statistics for the previous epoch and save the model.
      last_stats_step = global_step
      avg_step_time = step_time / steps_per_stats
      utils.add_summary(summary_writer, global_step, "step_time",
                        avg_step_time)
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      # words-per-second, in thousands.
      speed = checkpoint_total_count / (1000 * step_time)
      if math.isnan(train_ppl):
        # Training diverged; stop.
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

      # save the model
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "dialogue.ckpt"),
          global_step=global_step)

      # print the dialogue if in debug mode
      if hparams.debug:
        utils.print_current_dialogue(words1, words2, mask1, mask2)

    # write out internal evaluation
    if global_step - last_eval_step >= steps_per_internal_eval:
      last_eval_step = global_step
      utils.print_out("# Internal Evaluation. global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

  # finished training: save the final checkpoint and report totals.
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "dialogue.ckpt"),
      global_step=global_step)

  result_summary = ""
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step,
       loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
  utils.print_time("# Done training!", start_train_time)

  utils.print_out("# Start evaluating saved best models.")
  summary_writer.close()
def train(hparams, scope=None, target_session=""):
  """Train a translation model.

  Builds train/eval/infer graphs, restores or initializes the model in
  `hparams.out_dir`, runs the training loop with periodic stats, internal
  and external evaluation, then re-evaluates the saved best (and averaged)
  checkpoints for each metric.

  Args:
    hparams: hyperparameter object; `hparams.epoch_step` is mutated in place.
    scope: optional variable scope passed to the model-building helpers.
    target_session: TF session target (for distributed execution).

  Returns:
    A (final_eval_metrics, global_step) tuple.
  """
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  steps_per_eval = 10 * steps_per_stats
  avg_ckpts = hparams.avg_ckpts

  # Default the external-eval cadence when not set explicitly.
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  # Create model
  model_creator = get_model_creator(hparams)
  train_model = model_helper.create_train_model(model_creator, hparams, scope)
  eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
  infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

  # Preload data for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  # One session per graph (train / eval / infer).
  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)
  infer_sess = tf.Session(
      target=target_session, config=config_proto, graph=infer_model.graph)

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  # First evaluation
  run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data,
                avg_ckpts)

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  stats, info, start_train_time = before_train(
      loaded_train_model, train_model, train_sess, global_step, hparams, log_f)
  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      run_external_eval(infer_model, infer_sess, model_dir, hparams,
                        summary_writer)
      if avg_ckpts:
        run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, global_step)
      # Restart the training iterator from the beginning of the dataset.
      train_sess.run(train_model.iterator.initializer,
                     feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Process step_result, accumulate stats, and write summary
    global_step, info["learning_rate"], step_summary = update_stats(
        stats, start_time, step_result)
    summary_writer.add_summary(step_summary, global_step)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      is_overflow = process_stats(stats, info, global_step, steps_per_stats,
                                  log_f)
      print_step_info("  ", global_step, info, get_best_results(hparams),
                      log_f)
      if is_overflow:
        break
      # Reset statistics
      stats = init_stats()

    if global_step - last_eval_step >= steps_per_eval:
      last_eval_step = global_step
      utils.print_out("# Save eval, global step %d" % global_step)
      add_info_summaries(summary_writer, global_step, info)
      # Save checkpoint
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "translate.ckpt"),
                                    global_step=global_step)
      # Evaluate on dev/test
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                        summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step
      # Save checkpoint
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "translate.ckpt"),
                                    global_step=global_step)
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      run_external_eval(infer_model, infer_sess, model_dir, hparams,
                        summary_writer)
      if avg_ckpts:
        run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, global_step)

  # Done training
  loaded_train_model.saver.save(train_sess,
                                os.path.join(out_dir, "translate.ckpt"),
                                global_step=global_step)

  # Final full evaluation on the last checkpoint.
  (result_summary, _, final_eval_metrics) = (run_full_eval(
      model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
      summary_writer, sample_src_data, sample_tgt_data, avg_ckpts))
  print_step_info("# Final, ", global_step, info, result_summary, log_f)
  utils.print_time("# Done training!", start_train_time)

  summary_writer.close()

  # Re-evaluate the best checkpoint saved for each tracked metric.
  utils.print_out("# Start evaluating saved best models.")
  for metric in hparams.metrics:
    best_model_dir = getattr(hparams, "best_" + metric + "_dir")
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
        hparams, summary_writer, sample_src_data, sample_tgt_data)
    print_step_info("# Best %s, " % metric, best_global_step, info,
                    result_summary, log_f)
    summary_writer.close()

    if avg_ckpts:
      # Same evaluation for the checkpoint-averaged best model.
      best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
      summary_writer = tf.summary.FileWriter(
          os.path.join(best_model_dir, summary_name), infer_model.graph)
      result_summary, best_global_step, _ = run_full_eval(
          best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
          hparams, summary_writer, sample_src_data, sample_tgt_data)
      print_step_info("# Averaged Best %s, " % metric, best_global_step,
                      info, result_summary, log_f)
      summary_writer.close()

  return final_eval_metrics, global_step
def train(hps, scope=None, target_session=""):
  """Train a translation model.

  Summarization variant: picks a random dev article/abstract pair for sample
  decoding, trains until `hps.num_train_steps` with periodic internal and
  external evaluation, then re-evaluates the best checkpoint per metric.

  Args:
    hps: hyperparameter object (out_dir, data_dir, dev_prefix, batch_size,
      epoch_step, attention_architecture, metrics, ...).
    scope: optional variable scope passed to the model-building helpers.
    target_session: TF session target (for distributed execution).

  Returns:
    A (dev_scores, test_scores, dev_ppl, test_ppl, global_step) tuple.
  """
  log_device_placement = hps.log_device_placement
  out_dir = hps.out_dir
  num_train_steps = hps.num_train_steps
  steps_per_stats = hps.steps_per_stats
  steps_per_external_eval = hps.steps_per_external_eval
  steps_per_eval = 100 * steps_per_stats
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  # Pick the model class from the attention configuration.
  if hps.attention_architecture == "baseline":
    model_creator = AttentionModel
  else:
    model_creator = AttentionHistoryModel

  train_model = model_helper.create_train_model(model_creator, hps, scope)
  eval_model = model_helper.create_eval_model(model_creator, hps, scope)
  infer_model = model_helper.create_infer_model(model_creator, hps, scope)

  # Preload data for sample decoding.
  article_filenames = []
  abstract_filenames = []
  art_dir = hps.data_dir + '/article'
  abs_dir = hps.data_dir + '/abstract'
  for file in os.listdir(art_dir):
    if file.startswith(hps.dev_prefix):
      article_filenames.append(art_dir + "/" + file)
  for file in os.listdir(abs_dir):
    if file.startswith(hps.dev_prefix):
      abstract_filenames.append(abs_dir + "/" + file)
  # NOTE(review): pairing article/abstract by index assumes os.listdir
  # returns both directories in the same order — verify the file naming
  # guarantees this.

  # if random_decode:
  #   """if this is a random sampling process during training"""
  decode_id = random.randint(0, len(article_filenames) - 1)
  single_article_file = article_filenames[decode_id]
  single_abstract_file = abstract_filenames[decode_id]
  dev_src_file = single_article_file
  dev_tgt_file = single_abstract_file
  sample_src_data = inference_base_model.load_data(dev_src_file)
  sample_tgt_data = inference_base_model.load_data(dev_tgt_file)

  summary_name = "train_log"
  model_dir = hps.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  avg_step_time = 0.0

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement)
  # One session per graph (train / eval / infer).
  train_sess = tf.Session(target=target_session, config=config_proto,
                          graph=train_model.graph)
  eval_sess = tf.Session(target=target_session, config=config_proto,
                         graph=eval_model.graph)
  infer_sess = tf.Session(target=target_session, config=config_proto,
                          graph=infer_model.graph)

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                         train_model.graph)

  # First evaluation (currently disabled).
  # run_full_eval(
  #     model_dir, infer_model, infer_sess,
  #     eval_model, eval_sess, hps,
  #     summary_writer,sample_src_data,sample_tgt_data)

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  stats = init_stats()
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(
          session=train_sess), time.ctime()), log_f)

  # Initialize all of the iterators; skip data already consumed this epoch.
  skip_count = hps.batch_size * hps.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(train_model.iterator.initializer,
                 feed_dict={train_model.skip_count_placeholder: skip_count})
  # NOTE(review): this variant tracks epoch_step in a local (not hps), so the
  # skip_count above never reflects mid-epoch progress on restart — confirm
  # intended.
  epoch_step = 0
  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      run_sample_decode(infer_model, infer_sess, model_dir, hps,
                        summary_writer, sample_src_data, sample_tgt_data)
      dev_scores, test_scores, _ = run_external_eval(
          infer_model, infer_sess, model_dir, hps, summary_writer)
      train_sess.run(train_model.iterator.initializer,
                     feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary and accumulate statistics
    global_step = update_stats(stats, summary_writer, start_time, step_result)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      is_overflow = check_stats(stats, global_step, steps_per_stats, hps,
                                log_f)
      if is_overflow:
        break
      # Reset statistics
      stats = init_stats()

    if global_step - last_eval_step >= steps_per_eval:
      last_eval_step = global_step
      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)
      # Save checkpoint
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "summarized.ckpt"),
                                    global_step=global_step)
      # Evaluate on dev/test
      run_sample_decode(infer_model, infer_sess, model_dir, hps,
                        summary_writer, sample_src_data, sample_tgt_data)
      dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir,
                                            hps, summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step
      # Save checkpoint
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "summarized.ckpt"),
                                    global_step=global_step)
      run_sample_decode(infer_model, infer_sess, model_dir, hps,
                        summary_writer, sample_src_data, sample_tgt_data)
      dev_scores, test_scores, _ = run_external_eval(
          infer_model, infer_sess, model_dir, hps, summary_writer)

  # Done training
  loaded_train_model.saver.save(train_sess,
                                os.path.join(out_dir, "summarized.ckpt"),
                                global_step=global_step)

  # Final full evaluation on the last checkpoint.
  result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
      model_dir, infer_model, infer_sess, eval_model, eval_sess, hps,
      summary_writer, sample_src_data, sample_tgt_data)
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
  utils.print_time("# Done training!", start_train_time)

  summary_writer.close()

  # Re-evaluate the best checkpoint saved for each tracked metric.
  utils.print_out("# Start evaluating saved best models.")
  for metric in hps.metrics:
    best_model_dir = getattr(hps, "best_" + metric + "_dir")
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _, _, _, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hps,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Best %s, step %d "
        "step-time %.2f wps %.2fK, %s, %s" %
        (metric, best_global_step, avg_step_time, speed,
         result_summary, time.ctime()), log_f)
    summary_writer.close()

  return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
def train(flags):
  """Train the policy gradient model.

  Restores or initializes the model in `flags.out_dir`, trains until
  `flags.num_train_steps`, and every `flags.steps_per_infer` steps saves a
  checkpoint and runs inference, logging summaries for both modes.

  Args:
    flags: configuration namespace (out_dir, num_train_steps,
      steps_per_infer, and the model-selection flags consumed by
      get_model_creator).
  """
  out_dir = flags.out_dir
  num_train_steps = flags.num_train_steps
  steps_per_infer = flags.steps_per_infer
  ckpt_path = os.path.join(out_dir, "rl.ckpt")

  # Build the train and infer graphs from the same model class.
  model_creator = get_model_creator(flags)
  train_model = model_helper.create_train_model(flags, model_creator)
  infer_model = model_helper.create_infer_model(flags, model_creator)

  # TODO. set for distributed training and multi gpu
  session_config = tf.ConfigProto(allow_soft_placement=True)
  session_config.gpu_options.allow_growth = True

  # One session per graph.
  train_sess = tf.Session(config=session_config, graph=train_model.graph)
  infer_sess = tf.Session(config=session_config, graph=infer_model.graph)

  # Restore the latest checkpoint if present; otherwise initialize variables.
  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        out_dir, train_model.model, train_sess)

  # Separate summary writers for the train and infer logs.
  train_summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)
  infer_summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "infer_log"))

  # Evaluate once before any training.
  run_infer(infer_model, out_dir, infer_sess)

  last_infer_steps = global_step

  # Training loop.
  while global_step < num_train_steps:
    train_out = loaded_train_model.train(train_sess)
    global_step = train_out.global_step
    train_summary_writer.add_summary(train_out.train_summary, global_step)
    print('current global_step: {}'.format(global_step))

    if global_step - last_infer_steps >= steps_per_infer:
      # Checkpoint first, then run inference against the saved weights.
      loaded_train_model.saver.save(train_sess, ckpt_path, global_step)
      last_infer_steps = global_step
      infer_out = run_infer(infer_model, out_dir, infer_sess)
      infer_summary_writer.add_summary(infer_out.infer_summary, global_step)

  # Final checkpoint.
  loaded_train_model.saver.save(train_sess, ckpt_path, global_step)
  print('Train done')
def train(hparams, scope=None, target_session=""):
    """Train a translation model, optionally with a curriculum.

    Supports three curriculum modes selected by hparams.curriculum:
      - 'none': single iterator, plain training.
      - 'predictive_gain': Exp3S bandit picks a length bucket ("lesson");
        reward is the loss reduction after the step.
      - 'look_back_and_forward': advances a curriculum pointer through the
        buckets as the loss drops below a threshold.

    Returns:
        (dev_scores, test_scores, dev_ppl, test_ppl, global_step).
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    # Pick the model class from the attention configuration.
    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = model_helper.create_train_model(model_creator, hparams, scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="w")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    # Three sessions over three graphs (train/eval/infer) sharing checkpoints
    # on disk rather than variables in memory.
    train_sess = tf.Session(target=target_session, config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators.  With a curriculum there is one
    # sub-iterator per length bucket, addressed via string handles.
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    if hparams.curriculum == 'none':
        train_sess.run(
            train_model.iterator.initializer,
            feed_dict={train_model.skip_count_placeholder: skip_count})
    else:
        if hparams.curriculum == 'predictive_gain':
            # Exp3S bandit over num_curriculum_buckets arms.
            exp3s = Exp3S(hparams.num_curriculum_buckets, 0.001, 0, 0.05)
        elif hparams.curriculum == 'look_back_and_forward':
            curriculum_point = 0
        handle = train_model.iterator.handle
        for i in range(hparams.num_curriculum_buckets):
            train_sess.run(
                train_model.iterator.initializer[i].initializer,
                feed_dict={train_model.skip_count_placeholder: skip_count})
        iterator_handles = [
            train_sess.run(
                train_model.iterator.initializer[i].string_handle(),
                feed_dict={train_model.skip_count_placeholder: skip_count})
            for i in range(hparams.num_curriculum_buckets)
        ]
    utils.print_out("Starting training")

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            if hparams.curriculum != 'none':
                # Choose the lesson (length bucket) for this step.
                if hparams.curriculum == 'predictive_gain':
                    lesson = exp3s.draw_task()
                elif hparams.curriculum == 'look_back_and_forward':
                    if curriculum_point == hparams.num_curriculum_buckets:
                        lesson = np.random.randint(
                            low=0, high=hparams.num_curriculum_buckets)
                    else:
                        # 80% exploit the current point, 20% explore.
                        lesson = curriculum_point if np.random.random_sample(
                        ) < 0.8 else np.random.randint(
                            low=0, high=hparams.num_curriculum_buckets)
                step_result = loaded_train_model.train(
                    hparams,
                    train_sess,
                    handle=handle,
                    iterator_handle=iterator_handles[lesson],
                    use_fed_source_placeholder=loaded_train_model.
                    use_fed_source,
                    fed_source_placeholder=loaded_train_model.fed_source)
                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size, source) = step_result
                if hparams.curriculum == 'predictive_gain':
                    # Re-evaluate the loss on the same source after the update;
                    # the drop (step_loss - new_loss) is the bandit reward.
                    # NOTE(review): session.run([...]) returns a one-element
                    # list here, so new_loss is a list — the subtraction below
                    # presumably relies on numpy broadcasting; verify.
                    new_loss = train_sess.run(
                        [loaded_train_model.train_loss],
                        feed_dict={
                            handle: iterator_handles[lesson],
                            loaded_train_model.use_fed_source: True,
                            loaded_train_model.fed_source: source
                        })
                    # Bucket midpoint in source-length units.
                    curriculum_point_a = lesson * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (lesson + 1) * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    v = step_loss - new_loss
                    exp3s.update_w(
                        v, float(curriculum_point_a + curriculum_point_b) / 2.0)
                elif hparams.curriculum == 'look_back_and_forward':
                    utils.print_out("step loss: %s, lesson: %s" %
                                    (step_loss, lesson))
                    curriculum_point_a = curriculum_point * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (curriculum_point + 1) * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    # Advance the curriculum once the loss is low enough
                    # relative to the bucket midpoint.
                    if step_loss < (hparams.curriculum_progress_loss *
                                    (float(curriculum_point_a +
                                           curriculum_point_b) / 2.0)):
                        curriculum_point += 1
            else:
                step_result = loaded_train_model.train(hparams, train_sess)
                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            hparams.epoch_step = 0
            if hparams.curriculum == 'none':
                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            else:
                # Only the exhausted bucket's iterator is re-initialized.
                train_sess.run(
                    train_model.iterator.initializer[lesson].initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)
        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            if hparams.curriculum == 'predictive_gain':
                utils.print_out("lesson: %s, step loss: %s, new_loss: %s" %
                                (lesson, step_loss, new_loss))
                utils.print_out("exp3s dist: %s" % (exp3s.pi, ))
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss /
                                       checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                " global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)

    summary_writer.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
def main(_):
    """Entry point: build the train graph and run it under parallax.

    Creates hparams and the train model once, then hands the graph to
    parallax.parallel_run, which returns a (possibly replicated) session;
    the nested run() performs the per-worker training loop.  Only worker 0
    collects statistics and writes logs.
    """
    default_hparams = nmt.create_hparams(FLAGS)
    ## Train / Decode
    out_dir = FLAGS.out_dir
    if not tf.gfile.Exists(out_dir):
        tf.gfile.MakeDirs(out_dir)
    # Load hparams.
    hparams = nmt.create_or_load_hparams(out_dir, default_hparams,
                                         FLAGS.hparams_path,
                                         save_hparams=False)

    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    # Pick the model class from the attention configuration.
    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model =\
        model_helper.create_train_model(model_creator, hparams, scope=None)

    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=1,
        num_inter_threads=36)

    def run(train_sess, num_workers, worker_id, num_replicas_per_worker):
        # Per-worker training loop; train_sess is the parallax session
        # (tensor fetches come back as one value per replica).

        # Random: offset the seed per worker so workers differ.
        random_seed = FLAGS.random_seed
        if random_seed is not None and random_seed > 0:
            utils.print_out("# Set random seed to %d" % random_seed)
            random.seed(random_seed + worker_id)
            np.random.seed(random_seed + worker_id)

        # Log and output files
        log_file = os.path.join(out_dir, "log_%d" % time.time())
        log_f = tf.gfile.GFile(log_file, mode="a")
        utils.print_out("# log_file=%s" % log_file, log_f)

        # [0]: fetches are per-replica lists; take replica 0's value.
        global_step = train_sess.run(train_model.model.global_step)[0]
        last_stats_step = global_step

        # This is the training loop.
        stats, info, start_train_time = before_train(train_model, train_sess,
                                                     global_step, hparams,
                                                     log_f,
                                                     num_replicas_per_worker)
        # NOTE(review): true division — epoch_steps may be fractional if
        # epoch_size doesn't divide evenly; the modulo test below still
        # works but only fires when the remainder is exactly 0.
        epoch_steps = FLAGS.epoch_size / (FLAGS.batch_size * num_workers *
                                          num_replicas_per_worker)
        for i in range(FLAGS.max_steps):
            ### Run a step ###
            start_time = time.time()

            # Re-initialize the iterator at epoch boundaries.
            if hparams.epoch_step != 0 and hparams.epoch_step % epoch_steps == 0:
                hparams.epoch_step = 0
                skip_count = train_model.skip_count_placeholder
                feed_dict = {}
                feed_dict[skip_count] = [
                    0 for i in range(num_replicas_per_worker)
                ]
                init = train_model.iterator.initializer
                train_sess.run(init, feed_dict=feed_dict)

            if worker_id == 0:
                # Worker 0 fetches everything needed for statistics.
                results = train_sess.run([
                    train_model.model.update, train_model.model.train_loss,
                    train_model.model.predict_count,
                    train_model.model.train_summary,
                    train_model.model.global_step,
                    train_model.model.word_count,
                    train_model.model.batch_size,
                    train_model.model.grad_norm,
                    train_model.model.learning_rate
                ])
                step_result = [r[0] for r in results]
            else:
                # Other workers only step the optimizer.
                global_step, _ = train_sess.run(
                    [train_model.model.global_step, train_model.model.update])
            hparams.epoch_step += 1

            if worker_id == 0:
                # Process step_result, accumulate stats, and write summary
                # NOTE(review): `train` here is the nmt train module, not
                # this function — confirm against the file's imports.
                global_step, info["learning_rate"], step_summary = \
                    train.update_stats(stats, start_time, step_result)

                # Once in a while, we print statistics.
                if global_step - last_stats_step >= steps_per_stats:
                    last_stats_step = global_step
                    is_overflow = train.process_stats(stats, info, global_step,
                                                      steps_per_stats, log_f)
                    train.print_step_info(" ", global_step, info,
                                          train._get_best_results(hparams),
                                          log_f)
                    if is_overflow:
                        break

                    # Reset statistics
                    stats = train.init_stats()

    sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(train_model.graph,
                              FLAGS.resource_info_file,
                              sync=FLAGS.sync,
                              parallax_config=parallax_config.build_config())
    run(sess, num_workers, worker_id, num_replicas_per_worker)
def train(hparams, scope=None):
    """Train a GNMT model for hparams.epochs epochs with periodic eval.

    Runs sample decode / internal / external evaluation on three cadences
    (steps_per_stats, steps_per_eval, steps_per_external_eval) plus a full
    evaluation at every epoch boundary, then evaluates the saved best-BLEU
    (and optionally averaged) checkpoints.

    Returns:
        (final_eval_metrics, global_step).
    """
    model_dir = hparams.out_dir
    avg_ckpts = hparams.avg_ckpts
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval
    summary_name = "summary"

    # Three graphs/sessions (train/eval/infer) share weights via checkpoints.
    model_creator = gnmt_model.GNMTModel
    train_model = model_helper.create_train_model(model_creator, hparams)
    eval_model = model_helper.create_eval_model(model_creator, hparams)
    infer_model = model_helper.create_infer_model(model_creator, hparams)

    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    train_sess = tf.Session(graph=train_model.graph, config=config_proto)
    eval_sess = tf.Session(graph=eval_model.graph, config=config_proto)
    infer_sess = tf.Session(graph=infer_model.graph, config=config_proto)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(model_dir, summary_name), train_model.graph)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = utils.load_data(dev_src_file)
    sample_tgt_data = utils.load_data(dev_tgt_file)

    # First evaluation
    result_summary, _, _ = run_full_eval(model_dir, infer_model, infer_sess,
                                         eval_model, eval_sess, hparams,
                                         summary_writer, sample_src_data,
                                         sample_tgt_data, avg_ckpts)
    utils.log('First evaluation: {}'.format(result_summary))

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    info = {
        "train_ppl": 0.0,
        "speed": 0.0,
        "avg_step_time": 0.0,
        "avg_grad_norm": 0.0,
        "learning_rate":
            loaded_train_model.learning_rate.eval(session=train_sess)
    }
    utils.log("Start step %d, lr %g" % (global_step, info["learning_rate"]))

    # Initialize all of the iterators
    train_sess.run(train_model.iterator.initializer)

    epoch = 1
    while True:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            utils.log(
                "Finished epoch %d, step %d. Perform external evaluation" %
                (epoch, global_step))
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)
            train_sess.run(train_model.iterator.initializer)
            # Stop after hparams.epochs epochs; otherwise continue.
            if epoch < hparams.epochs:
                epoch += 1
                continue
            else:
                break

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats)
            print_step_info(" ", global_step, info,
                            "BLEU %.2f" % (hparams.best_bleu, ))
            if is_overflow:
                # Perplexity overflowed (NaN/inf); abort training.
                break
            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.log("Save eval, global step %d" % (global_step, ))
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])
            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(
                                              model_dir, "translate.ckpt"),
                                          global_step=global_step)
            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(
                                              model_dir, "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(model_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)
    print_step_info("Final, ", global_step, info, result_summary)
    utils.log("Done training!")
    summary_writer.close()

    # Re-evaluate the best-BLEU checkpoint with its own summary writer.
    utils.log("Start evaluating saved best models.")
    best_model_dir = hparams.best_bleu_dir
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
        hparams, summary_writer, sample_src_data, sample_tgt_data)
    print_step_info("Best BLEU, ", best_global_step, info, result_summary)
    summary_writer.close()

    if avg_ckpts:
        # Same for the checkpoint-averaged best model.
        best_model_dir = hparams.avg_best_bleu_dir
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("Averaged Best BLEU, ", best_global_step, info,
                        result_summary)
        summary_writer.close()

    return final_eval_metrics, global_step
def train(hparams):
    """Train the RNN model described by `hparams` for num_epochs epochs.

    Builds separate train/eval graphs sharing variables through
    checkpoints, evaluates before training, then trains epoch by epoch,
    checkpointing and evaluating every num_ckpt_epochs epochs.

    Fixes over the previous version:
      * batch timing: `start_batch_time` was initialized to 0.0, so
        `epoch_time += time.time() - start_batch_time` accumulated absolute
        wall-clock timestamps instead of batch durations;
      * per-epoch stats: `epoch_loss`, `epoch_time` and `batch_count` were
        never reset at the start of an epoch, so each epoch's average was
        contaminated by the previous epoch's already-averaged value;
      * guards the per-epoch averaging against an empty epoch
        (batch_count == 0).

    Args:
        hparams: hyper-parameter object (num_epochs, num_ckpt_epochs,
            out_dir, input/target paths, optional input_emb_file, ...).
    """
    num_epochs = hparams.num_epochs
    num_ckpt_epochs = hparams.num_ckpt_epochs
    summary_name = "train_log"
    out_dir = hparams.out_dir
    model_dir = out_dir
    log_device_placement = hparams.log_device_placement

    # Optional pretrained input embeddings, whitespace-delimited text matrix.
    input_emb_weights = np.loadtxt(
        hparams.input_emb_file,
        delimiter=' ') if hparams.input_emb_file else None

    if hparams.model_architecture == "rnn-model":
        model_creator = model.RNN
    else:
        raise ValueError(
            "Unknown model architecture. Only simple_rnn is supported so far.")

    # Create 2 models in 2 graphs for train and evaluation, with 2 sessions
    # sharing the same variables (via checkpoints on disk).
    train_model = model_helper.create_train_model(
        model_creator,
        hparams,
        hparams.train_input_path,
        hparams.train_target_path,
        mode=tf.contrib.learn.ModeKeys.TRAIN)
    eval_model = model_helper.create_eval_model(model_creator, hparams,
                                                tf.contrib.learn.ModeKeys.EVAL)

    # GPU/device-placement logging configuration.
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement, allow_soft_placement=True)

    # Two separate sessions for train/eval.
    train_sess = tf.Session(config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(config=config_proto, graph=eval_model.graph)

    # Initialize the train graph's variables, or load them from the latest
    # checkpoint in model_dir.  The eval graph stays uninitialized here.
    with train_model.graph.as_default():
        loaded_train_model = model_helper.create_or_load_model(
            train_model.model, train_sess, "train", model_dir,
            input_emb_weights)

    # Async event log in out_dir; pass the graph so Tensorboard can show it.
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # Run a first evaluation before any training as a baseline.
    dev_loss = run_evaluation(eval_model, eval_sess, model_dir,
                              hparams.val_input_path, hparams.val_target_path,
                              input_emb_weights, summary_writer)
    train_loss = run_evaluation(eval_model, eval_sess, model_dir,
                                hparams.train_input_path,
                                hparams.train_target_path, input_emb_weights,
                                summary_writer)
    print("Dev loss before training: %.3f" % dev_loss)
    print("Train loss before training: %.3f" % train_loss)

    # Start training
    start_train_time = time.time()

    # Initialize the train iterator in train_sess.
    train_sess.run(train_model.iterator.initializer)

    # Per-epoch train/val losses across the whole run.
    train_losses = []
    dev_losses = []

    # One epoch == one full pass over the training batches.
    for epoch in range(num_epochs):
        # Reset per-epoch accumulators (previously leaked across epochs).
        epoch_time = 0.0
        epoch_loss = 0.0
        batch_count = 0.0

        # Go through all batches for the current epoch.
        while True:
            # Previously 0.0, which made epoch_time accumulate absolute
            # timestamps rather than batch durations.
            start_batch_time = time.time()
            try:
                # Run the train ops of the train graph in train_sess.
                step_result = loaded_train_model.train(train_sess)
                (_, batch_loss, batch_summary, global_step, learning_rate,
                 batch_size, inputs, targets) = step_result
                epoch_time += (time.time() - start_batch_time)
                epoch_loss += batch_loss
                batch_count += 1
            except tf.errors.OutOfRangeError:
                # Iterator exhausted: re-initialize for the next epoch.
                train_sess.run(train_model.iterator.initializer)
                break

        # Average epoch loss and epoch time over batches (guard empty epoch).
        if batch_count:
            epoch_loss /= batch_count
            epoch_time /= batch_count

        # Checkpoint + evaluate every num_ckpt_epochs epochs.
        if (epoch + 1) % num_ckpt_epochs == 0:
            print("Saving checkpoint...")
            model_helper.add_summary(summary_writer, "train_loss", epoch_loss)
            # Save the train-graph variables; epoch is appended to the name.
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir, "rnn.ckpt"),
                                          global_step=epoch)
            print("Results: ")
            dev_loss = run_evaluation(eval_model, eval_sess, model_dir,
                                      hparams.val_input_path,
                                      hparams.val_target_path,
                                      input_emb_weights, summary_writer)
            print(" epoch %d lr %g "
                  "train_loss %.3f, dev_loss %.3f" %
                  (epoch,
                   loaded_train_model.learning_rate.eval(session=train_sess),
                   epoch_loss, dev_loss))
            train_losses.append(epoch_loss)
            dev_losses.append(dev_loss)

    # Save final model.
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "rnn.ckpt"),
                                  global_step=num_epochs)
    print("Done training in %.2fK" % (time.time() - start_train_time))
    min_dev_loss = np.min(dev_losses)
    min_dev_idx = np.argmin(dev_losses)
    print("Min val loss: %.3f at epoch %d" % (min_dev_loss, min_dev_idx))
    summary_writer.close()
def train(hparams, scope=None, target_session=""):
    """Train a translation model.

    Standard step-driven NMT training loop: statistics every
    steps_per_stats steps, save/eval every steps_per_eval, external eval
    every steps_per_external_eval, plus a full evaluation at the end and
    for every saved best-metric model.

    Returns:
        (dev_scores, test_scores, dev_ppl, test_ppl, global_step).
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    # Pick the model class from the attention configuration.
    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = model_helper.create_train_model(model_creator, hparams, scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    # Three sessions over three graphs; weights flow via checkpoints.
    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()), log_f)

    # Initialize all of the iterators; skip what this (resumed) epoch
    # has already consumed.
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)
            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)
        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss /
                                       checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                " global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # Training diverged; stop.
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(out_dir, "translate.ckpt"),
                global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(
                eval_model, eval_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(out_dir, "translate.ckpt"),
                global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "translate.ckpt"),
        global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    # Evaluate every saved "best" model, one writer per metric directory.
    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out("# Best %s, step %d "
                        "step-time %.2f wps %.2fK, %s, %s" %
                        (metric, best_global_step, avg_step_time, speed,
                         result_summary, time.ctime()), log_f)
        summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
def train(hparams, ckpt_dir, scope=None, target_session="", alternative=False):
    """Train a model, checkpointing at every epoch boundary.

    Creates a train graph/session, restores (or initializes) it from
    `ckpt_dir`, then runs `loaded_train_model.train` until `global_step`
    reaches `hparams.num_train_steps`. Epoch boundaries are detected via
    `tf.errors.OutOfRangeError` from the exhausted dataset iterator; at each
    boundary the model is checkpointed and, if `alternative` is True,
    evaluation is run alternately with training.

    Args:
        hparams: hyperparameters; fields read here include base_dir,
            num_train_steps, steps_per_stats, model_type, img_batch_size,
            num_intra_threads, num_inter_threads.
        ckpt_dir: directory to load from and save checkpoints to
            (files named "model.ckpt-<global_step>").
        scope: optional variable scope passed to model creation.
        target_session: tf.Session target (e.g. for distributed execution).
        alternative: if True, run `eval.evaluate` after each epoch checkpoint.
    """
    # All train-side artifacts (summaries, logs) go under <base_dir>/train.
    out_dir = os.path.join(hparams.base_dir, "train")
    if not misc.check_file_existence(out_dir):
        tf.gfile.MakeDirs(out_dir)
    tf.logging.info("All train relevant results will be put in %s" % out_dir)
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    ckpt_path = os.path.join(ckpt_dir, "model.ckpt")
    # Create model
    model_creator = get_model_creator(hparams.model_type)
    train_model = helper.create_train_model(model_creator, hparams, scope)
    config_proto = misc.get_config_proto(
        log_device_placement=False,
        allow_soft_placement=True,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    tf.logging.info("Create model successfully")
    # Initialize the train-graph variables, or restore them from the latest
    # checkpoint in ckpt_dir if one exists.
    with train_model.graph.as_default():
        loaded_train_model, global_step = helper.create_or_load_model(
            train_model.model, ckpt_dir, train_sess, "train")
    # Summary writer (also records the graph for TensorBoard).
    summary_name = "train_summary"
    summary_path = os.path.join(out_dir, summary_name)
    if not tf.gfile.Exists(summary_path):
        tf.gfile.MakeDirs(summary_path)
    summary_writer = tf.summary.FileWriter(summary_path, train_model.graph)
    last_stats_step = global_step
    # Training iteration bookkeeping (stats/info dicts come from before_train).
    # NOTE(review): start_train_time is never read after this point.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step)
    tf.logging.info("Ready to train")
    epoch_step = 0
    while global_step < num_train_steps:
        start_time = time.time()
        try:
            tf.logging.info("Start train epoch:%d" % epoch_step)
            step_result = loaded_train_model.train(train_sess)
            epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Dataset iterator exhausted: this is an epoch boundary.
            tf.logging.info(
                "Saving epoch step %d model into checkpoint" % epoch_step)
            loaded_train_model.saver.save(
                train_sess, ckpt_path, global_step=global_step)
            # Training while evaluating alternately
            if alternative:
                eval.evaluate(hparams, scope=scope,
                              target_session=target_session,
                              global_step_=global_step,
                              ckpt_path=ckpt_path, alternative=alternative)
            # Finished going through the training dataset. Go to next epoch.
            epoch_step = 0
            tf.logging.info("# Finished an epoch, step %d." % global_step)
            # Reinitialize the data iterator and retry the train step.
            train_sess.run(train_model.data_wrapper.initializer)
            continue
        # Fold this step's results into the running statistics.
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)
        # Once in a while, print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            tf.logging.info(
                "In global step:%d, time to print train statistics"
                % global_step)
            last_stats_step = global_step
            # Update info
            process_stats(stats, info, steps_per_stats, hparams.img_batch_size)
            print_step_info(global_step, info, stats)
            # Reset statistic
            stats = init_stats()
    # Done training
    tf.logging.info("Finish training, saving model into checkpoint")
    loaded_train_model.saver.save(
        train_sess, ckpt_path, global_step=global_step)
    summary_writer.close()
def train(hparams):
    """Train a sequence tagging model.

    Builds separate train/eval graphs and sessions, optionally loads external
    embedding weights and the latest checkpoint, then trains for
    `hparams.num_epochs` epochs. Every `hparams.num_ckpt_epochs` epochs a
    checkpoint is saved and the model is evaluated on the validation set.

    Args:
        hparams: hyperparameters; fields read here include num_epochs,
            num_ckpt_epochs, out_dir, log_device_placement, input_emb_file,
            model_architecture, train_input_path, train_target_path,
            val_input_path, val_target_path, timeline.

    Raises:
        ValueError: if hparams.model_architecture is not "simple_rnn".
    """
    num_epochs = hparams.num_epochs
    num_ckpt_epochs = hparams.num_ckpt_epochs
    summary_name = "train_log"
    out_dir = hparams.out_dir
    model_dir = out_dir
    log_device_placement = hparams.log_device_placement

    # Load external embedding vectors if a file is given as input; otherwise
    # leave None and let the model initialize its own embeddings.
    input_emb_weights = np.loadtxt(
        hparams.input_emb_file, delimiter=' ') if hparams.input_emb_file else None

    if hparams.model_architecture == "simple_rnn":
        model_creator = model.RNN
    else:
        raise ValueError(
            "Unknown model architecture. Only simple_rnn is supported so far.")

    # Create 2 models in 2 separate graphs for train and evaluation.
    train_model = model_helper.create_train_model(
        model_creator,
        hparams,
        hparams.train_input_path,
        hparams.train_target_path,
        mode=tf.contrib.learn.ModeKeys.TRAIN)
    eval_model = model_helper.create_eval_model(
        model_creator, hparams, tf.contrib.learn.ModeKeys.EVAL)

    # Device-placement logging configuration.
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement, allow_soft_placement=True)

    # Two separate sessions, one per graph.
    train_sess = tf.Session(config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(config=config_proto, graph=eval_model.graph)

    # Initialize the train-graph variables in train_sess, or load them from
    # the latest checkpoint in model_dir. The eval graph's variables are not
    # initialized here; run_evaluation restores them from checkpoints.
    with train_model.graph.as_default():
        loaded_train_model = model_helper.create_or_load_model(
            train_model.model, train_sess, "train", model_dir,
            input_emb_weights)

    # Asynchronous summary log; the graph is passed so TensorBoard can show it.
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # Run a first evaluation before training to get baseline numbers.
    val_loss, val_acc = run_evaluation(eval_model, eval_sess, model_dir,
                                       hparams.val_input_path,
                                       hparams.val_target_path,
                                       input_emb_weights, summary_writer)
    train_loss, train_acc = run_evaluation(eval_model, eval_sess, model_dir,
                                           hparams.train_input_path,
                                           hparams.train_target_path,
                                           input_emb_weights, summary_writer)
    print("Before training: Val loss %.3f, Val accuracy %.3f." %
          (val_loss, val_acc))
    print("Before training: Train loss %.3f Train acc %.3f" %
          (train_loss, train_acc))

    # Start training.
    start_train_time = time.time()
    avg_batch_time = 0.0
    batch_loss, epoch_loss, epoch_accuracy = 0.0, 0.0, 0.0
    batch_count = 0.0

    # Initialize the train iterator in train_sess.
    train_sess.run(train_model.iterator.initializer)

    # Keep per-checkpoint train/val losses to report the best one at the end.
    train_losses = []
    dev_losses = []

    # Optional op-level timeline tracing (Chrome trace format).
    options = None
    run_metadata = None
    if hparams.timeline:
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

    step = 0
    for epoch in range(num_epochs):
        # Go through all batches for the current epoch.
        while True:
            start_batch_time = time.time()
            try:
                if hparams.timeline and step % 10 == 0:
                    # Traced train step: collect run metadata for the timeline.
                    step_result = loaded_train_model.train(
                        train_sess, options=options,
                        run_metadata=run_metadata)
                    summary_writer.add_run_metadata(
                        run_metadata, 'step%d' % step)
                    fetched_timeline = timeline.Timeline(
                        run_metadata.step_stats)
                    chrome_trace = (
                        fetched_timeline.generate_chrome_trace_format())
                    # BUGFIX: the original checked Exists() on the unformatted
                    # template path and then MakeDirs(out_dir), which never
                    # created the 'timelines' subdirectory, so the open()
                    # below failed with FileNotFoundError. Create the actual
                    # target directory instead.
                    timeline_dir = os.path.join(out_dir, 'timelines')
                    if not tf.gfile.Exists(timeline_dir):
                        tf.gfile.MakeDirs(timeline_dir)
                    timeline_path = os.path.join(
                        timeline_dir, 'timeline_02_step_%d.json' % step)
                    with open(timeline_path, 'w') as f:
                        f.write(chrome_trace)
                else:
                    # Plain (untraced) train step.
                    step_result = loaded_train_model.train(
                        train_sess, options=None, run_metadata=None)
                (_, batch_loss, batch_summary, global_step, learning_rate,
                 batch_size, batch_accuracy) = step_result
                avg_batch_time += (time.time() - start_batch_time)
                epoch_loss += batch_loss
                epoch_accuracy += batch_accuracy
                batch_count += 1
                step += 1
            except tf.errors.OutOfRangeError:
                # Iterator reached the end of the train data: reinitialize it
                # for the next epoch and leave the batch loop.
                train_sess.run(train_model.iterator.initializer)
                break

        # Average epoch loss / accuracy / batch time over batches.
        # Guard against an empty dataset (would raise ZeroDivisionError).
        if batch_count > 0:
            epoch_loss /= batch_count
            avg_batch_time /= batch_count
            epoch_accuracy /= batch_count
        print("Number of batches: %d" % batch_count)

        # Checkpoint + validation every num_ckpt_epochs epochs.
        if (epoch + 1) % num_ckpt_epochs == 0:
            print("Saving checkpoint...")
            model_helper.add_summary(summary_writer, "train_loss", epoch_loss)
            model_helper.add_summary(summary_writer, "train_accuracy",
                                     epoch_accuracy)
            # Save the train-graph variables; global_step is appended to the
            # checkpoint name.
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(out_dir, "rnn.ckpt"),
                global_step=epoch)
            print("Results: ")
            val_loss, val_accuracy = run_evaluation(
                eval_model, eval_sess, model_dir, hparams.val_input_path,
                hparams.val_target_path, input_emb_weights, summary_writer)
            print(
                " epoch %d lr %g "
                "train_loss %.3f, val_loss %.3f, train_accuracy %.3f, val accuracy %.3f, avg_batch_time %f"
                % (epoch,
                   loaded_train_model.learning_rate.eval(session=train_sess),
                   epoch_loss, val_loss, epoch_accuracy, val_accuracy,
                   avg_batch_time))
            train_losses.append(epoch_loss)
            dev_losses.append(val_loss)

        # Reset per-epoch statistics.
        batch_count = 0.0
        avg_batch_time = 0.0
        epoch_loss = 0.0
        epoch_accuracy = 0.0

    # Save final model.
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "rnn.ckpt"),
        global_step=num_epochs)
    print("Done training in %.2fK" % (time.time() - start_train_time))
    # Guard: dev_losses is empty when no checkpoint epoch ever ran
    # (num_epochs < num_ckpt_epochs); np.min would raise on an empty list.
    if dev_losses:
        min_dev_loss = np.min(dev_losses)
        # NOTE(review): this is the index into dev_losses (one entry per
        # checkpoint epoch), not the raw epoch number.
        min_dev_idx = np.argmin(dev_losses)
        print("Min val loss: %.3f at epoch %d" % (min_dev_loss, min_dev_idx))
    summary_writer.close()