def process_stats(stats, info, global_step, steps_per_stats, log_f):
  """Refresh the rolling `info` dict from `stats` and flag ppl overflow.

  Returns True when the training perplexity is NaN/inf or absurdly large,
  signalling the caller to stop training early.
  """
  # Per-step averages over the window since the last stats report.
  info["avg_step_time"] = stats["step_time"] / steps_per_stats
  info["avg_grad_norm"] = stats["grad_norm"] / steps_per_stats
  info["avg_sequence_count"] = stats["sequence_count"] / steps_per_stats
  # Thousands of words per second over the whole window.
  info["speed"] = stats["word_count"] / (1000 * stats["step_time"])
  # Perplexity over everything predicted since the last report.
  ppl = utils.safe_exp(stats["train_loss"] / stats["predict_count"])
  info["train_ppl"] = ppl
  # Overflow check: NaN/inf or a blown-up perplexity means divergence.
  if math.isnan(ppl) or math.isinf(ppl) or ppl > 1e20:
    utils.print_out(" step %d overflow, stop early" % global_step, log_f)
    return True
  return False
def compute_perplexity(self, sess, name):
  """Run eval over the whole dataset and return its perplexity."""
  loss_sum = 0
  predict_sum = 0
  # Drain the evaluation iterator until it is exhausted.
  while True:
    try:
      loss, predict_count, batch_size = self.eval(sess)
    except tf.errors.OutOfRangeError:
      break
    loss_sum += loss * batch_size
    predict_sum += predict_count
  ppl = utils.safe_exp(loss_sum / predict_sum)
  utils.log("%s perplexity: %.2f" % (name, ppl))
  return ppl
def check_stats(stats, global_step, steps_per_stats, hparams, log_f):
  """Log the stats window and report whether training perplexity overflowed."""
  # Window averages since the previous report.
  step_time = stats["step_time"] / steps_per_stats
  grad_norm = stats["grad_norm"] / steps_per_stats
  ppl = utils.safe_exp(stats["loss"] / stats["predict_count"])
  wps = stats["total_count"] / (1000 * stats["step_time"])
  utils.print_out(
      " global step %d lr %g step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s"
      % (global_step, stats["learning_rate"], step_time, wps, ppl, grad_norm,
         _get_best_results(hparams)),
      log_f)
  # NaN/inf/huge perplexity means the run has diverged.
  overflowed = math.isnan(ppl) or math.isinf(ppl) or ppl > 1e20
  if overflowed:
    utils.print_out(" step %d overflow, stop early" % global_step, log_f)
  return overflowed
def compute_perplexity(model, sess, name, eval_handle):
  """Compute perplexity of the output of the model based on loss function.

  Args:
    model: model wrapper exposing `eval(sess, handle)` which returns a
      (loss, all_summaries, predict_count, batch_size) tuple.
    sess: TensorFlow session to run the evaluation in.
    name: label used in the timing/log message.
    eval_handle: dataset iterator handle fed to `model.eval`.

  Returns:
    A (perplexity, aggregated_summaries) tuple. Summaries are summed over
    batches; every key except the whitelisted loss sums is then averaged
    per processed batch.
  """
  # Summary keys kept as running totals, never averaged per batch.
  # Fix: the original rebuilt this set on every loop iteration.
  keep_as_sum = frozenset([
      "eval_dialogue_loss1", "eval_dialogue_loss2", "eval_action_loss3"
  ])

  def aggregate_all_summaries(original, updates):
    """Accumulate `updates` into `original` in place and return it."""
    for key in updates:
      original[key] = original.get(key, 0.0) + updates[key]
    return original

  total_loss = 0
  total_predict_count = 0
  start_time = time.time()
  aggregated_summaries = {}
  batch_processed = 0
  # Drain the eval iterator; OutOfRangeError marks the end of the dataset.
  while True:
    try:
      loss, all_summaries, predict_count, batch_size = model.eval(
          sess, eval_handle)
    except tf.errors.OutOfRangeError:
      break
    total_loss += loss * batch_size
    batch_processed += 1
    total_predict_count += predict_count
    aggregated_summaries = aggregate_all_summaries(aggregated_summaries,
                                                   all_summaries)
  perplexity = utils.safe_exp(total_loss / total_predict_count)
  # Average the remaining summaries over the number of processed batches.
  for key in aggregated_summaries:
    if key not in keep_as_sum:
      aggregated_summaries[key] /= batch_processed
  utils.print_time(" eval %s: perplexity %.2f" % (name, perplexity),
                   start_time)
  return perplexity, aggregated_summaries
def train(hparams, identity, scope=None, target_session=""):
  """Main loop to train the dialogue model.

  Args:
    hparams: hyperparameter object; read for paths, batch size, step limits
      and debug flags. `hparams.epoch_step` is mutated to track progress
      within an epoch.
    identity: string prefixed to the log file name.
    scope: optional variable scope passed to model construction.
    target_session: TF session target (e.g. for distributed execution).
  """
  out_dir = hparams.out_dir
  steps_per_stats = hparams.steps_per_stats
  # Internal evaluation runs every 3 stats windows.
  steps_per_internal_eval = 3 * steps_per_stats
  model_creator = diag_model.Model
  train_model = model_helper.create_train_model(model_creator, hparams, scope)
  model_dir = hparams.out_dir
  # Log and output files.
  log_file = os.path.join(out_dir, identity + "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)
  avg_step_time = 0.0
  # Load TensorFlow session and model.
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)
  train_handle = train_sess.run(train_model.train_iterator.string_handle())
  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")
  # Initialize summary writer.
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)
  last_stats_step = global_step
  last_eval_step = global_step
  # Initialize training stats accumulated between stats reports.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()
  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()), log_f)
  # Initialize iterators, skipping examples already trained on this epoch.
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.train_iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})
  # Main training loop.
  while global_step < hparams.num_train_steps:
    start_time = time.time()
    try:
      # Run a step.
      step_result = loaded_train_model.train(train_sess, train_handle)
      (_, step_loss, all_summaries, step_predict_count, step_summary,
       global_step, step_word_count, batch_size, _, _,
       words1, words2, mask1, mask2) = step_result
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished an epoch: reset the iterator (no skip) and keep training.
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.train_iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue
    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    for key in all_summaries:
      utils.add_summary(summary_writer, global_step, key, all_summaries[key])
    # Update statistics accumulated since the last stats report.
    step_time += (time.time() - start_time)
    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)
    if global_step - last_stats_step >= steps_per_stats:
      # Print statistics for the previous window and save the model.
      last_stats_step = global_step
      avg_step_time = step_time / steps_per_stats
      utils.add_summary(summary_writer, global_step, "step_time",
                        avg_step_time)
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      # NaN perplexity means training diverged; stop the loop.
      if math.isnan(train_ppl):
        break
      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0
      # Save the model.
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "dialogue.ckpt"),
          global_step=global_step)
      # Print the dialogue if in debug mode.
      if hparams.debug:
        utils.print_current_dialogue(words1, words2, mask1, mask2)
    # Write out internal evaluation.
    if global_step - last_eval_step >= steps_per_internal_eval:
      last_eval_step = global_step
      utils.print_out("# Internal Evaluation. global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)
  # Finished training: save a final checkpoint and report totals.
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "dialogue.ckpt"),
      global_step=global_step)
  result_summary = ""
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step,
       loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
  utils.print_time("# Done training!", start_train_time)
  utils.print_out("# Start evaluating saved best models.")
  summary_writer.close()
def train(hparams, scope=None, target_session=''):
  """Train the chatbot.

  Builds train/eval/infer models (simple or hierarchical architecture),
  then loops until `hparams.num_train_steps`, periodically printing stats,
  checkpointing, sample-decoding and running internal/external evaluation.

  Args:
    hparams: hyperparameter object; `hparams.epoch_step` is mutated to
      track progress within an epoch.
    scope: optional variable scope for the models.
    target_session: TF session target (e.g. for distributed execution).

  Returns:
    A (dev_scores, test_scores, dev_ppl, test_ppl, global_step) tuple from
    the final full evaluation.
  """
  # Initialize some local hyperparameters.
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  steps_per_eval = 10 * steps_per_stats
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval
  if hparams.architecture == "simple":
    model_creator = SimpleModel
    get_infer_iterator = iterator_utils.get_infer_iterator
    get_iterator = iterator_utils.get_iterator
  elif hparams.architecture == "hier":
    model_creator = HierarchicalModel
    # Curry the end2end iterator constructors so they present the same
    # argument list as the simple ones; extra args come from hparams.
    def curry_get_infer_iterator(dataset, vocab_table, batch_size, src_reverse,
                                 eos, src_max_len):
      return end2end_iterator_utils.get_infer_iterator(
          dataset, vocab_table, batch_size, src_reverse, eos,
          src_max_len=src_max_len, eou=hparams.eou,
          dialogue_max_len=hparams.dialogue_max_len)
    get_infer_iterator = curry_get_infer_iterator
    def curry_get_iterator(src_dataset, tgt_dataset, vocab_table, batch_size,
                           sos, eos, src_reverse, random_seed, num_buckets,
                           src_max_len=None, tgt_max_len=None, num_threads=4,
                           output_buffer_size=None, skip_count=None):
      return end2end_iterator_utils.get_iterator(
          src_dataset, tgt_dataset, vocab_table, batch_size, sos, eos,
          eou=hparams.eou, src_reverse=src_reverse, random_seed=random_seed,
          num_dialogue_buckets=num_buckets, src_max_len=src_max_len,
          tgt_max_len=tgt_max_len, num_threads=num_threads,
          output_buffer_size=output_buffer_size, skip_count=skip_count)
    get_iterator = curry_get_iterator
  else:
    raise ValueError("Unkown architecture", hparams.architecture)
  # Create three models which share parameters through the use of checkpoints.
  train_model = create_train_model(model_creator, get_iterator, hparams, scope)
  eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
  infer_model = inference.create_infer_model(model_creator,
                                             get_infer_iterator, hparams,
                                             scope)
  # ToDo: adapt for architectures
  # Preload the data to use for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)
  summary_name = "train_log"
  model_dir = hparams.out_dir
  # Log and output files.
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)
  avg_step_time = 0.0
  # Create the configurations for the sessions.
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement)
  # Create three sessions, one for each model.
  train_sess = tf.Session(target=target_session, config=config_proto,
                          graph=train_model.graph)
  eval_sess = tf.Session(target=target_session, config=config_proto,
                         graph=eval_model.graph)
  infer_sess = tf.Session(target=target_session, config=config_proto,
                          graph=infer_model.graph)
  # Load the train model from checkpoint or create a new one.
  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, name="train")
  # Summary writer.
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)
  # First evaluation.
  run_full_eval(
      model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
      summary_writer, sample_src_data, sample_tgt_data)
  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step
  # This is the training loop.
  # Initialize the accumulators for the loop.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()
  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()), log_f)
  # epoch_step records where we were within an epoch; used to skip examples
  # already trained on.
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  # Initialize the training iterator.
  train_sess.run(
      train_model.iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})
  # Train until we reach num_steps.
  while global_step < num_train_steps:
    # Run a step.
    start_step_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      # The leading _ is the output of the update op.
      (_, step_loss, step_predict_count, step_summary, global_step,
       step_word_count, batch_size) = step_result
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      # Decode and print a random sentence.
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      # Perform external evaluation to save checkpoints if this is the best
      # for some metric.
      dev_scores, test_scores, _ = run_external_evaluation(
          infer_model, infer_sess, model_dir, hparams, summary_writer,
          save_on_best_dev=True)
      # Reinitialize the iterator from the beginning.
      train_sess.run(train_model.iterator.initializer,
                     feed_dict={train_model.skip_count_placeholder: 0})
      continue
    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    # Update statistics accumulated since the last stats report.
    step_time += (time.time() - start_step_time)
    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)
    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      # Print statistics for the previous window.
      avg_step_time = step_time / steps_per_stats
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      utils.print_out(
          " global step %d lr %g "
          "step-time %.2fs wps %.2fK ppl %.2f %s" %
          (global_step,
           loaded_train_model.learning_rate.eval(session=train_sess),
           avg_step_time, speed, train_ppl, _get_best_results(hparams)),
          log_f)
      if math.isnan(train_ppl):
        # The model has diverged; stop training.
        break
      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0
    if global_step - last_eval_step >= steps_per_eval:
      # Perform evaluation. Start by reassigning last_eval_step to the
      # current step.
      last_eval_step = global_step
      # Print the progress and add summary.
      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)
      # Save checkpoint.
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "chatbot.ckpt"),
                                    global_step=global_step)
      # Decode and print a random sample.
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      # Run internal evaluation, and update the ppl variables. The data
      # iterator is instantiated in the method.
      dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir,
                                            hparams, summary_writer)
    if global_step - last_external_eval_step >= steps_per_external_eval:
      # Run the external evaluation.
      last_external_eval_step = global_step
      # Save checkpoint.
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "chatbot.ckpt"),
                                    global_step=global_step)
      # Decode and print a random sample.
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      # Run external evaluation, updating metric scores in the meanwhile.
      # The unneeded output is the global step.
      dev_scores, test_scores, _ = run_external_evaluation(
          infer_model, infer_sess, model_dir, hparams, summary_writer,
          save_on_best_dev=True)
  # Done training. Save the model.
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "chatbot.ckpt"),
      global_step=global_step)
  result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
      model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
      summary_writer, sample_src_data, sample_tgt_data)
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step,
       loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
  utils.print_time("# Done training!", start_train_time)
  utils.print_out("# Start evaluating saved best models.")
  for metric in hparams.metrics:
    best_model_dir = getattr(hparams, "best_" + metric + "_dir")
    result_summary, best_global_step, _, _, _, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
        hparams, summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out("# Best %s, step %d "
                    "step-time %.2f wps %.2fK, %s, %s" %
                    (metric, best_global_step, avg_step_time, speed,
                     result_summary, time.ctime()), log_f)
  summary_writer.close()
  return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
def train(hparams, scope=None, target_session="", single_cell_fn=None):
  """Train a translation model.

  Two operating modes controlled by `hparams.eval_on_fly`: with it on,
  eval/infer models are built and internal/external evaluation runs on a
  schedule; with it off, only periodic snapshot checkpoints are saved.

  Args:
    hparams: hyperparameter object; `hparams.epoch_step` is mutated to
      track progress within an epoch.
    scope: optional variable scope for the models.
    target_session: TF session target (e.g. for distributed execution).
    single_cell_fn: optional cell factory forwarded to model creation.

  Returns:
    (dev_scores, test_scores, dev_ppl, test_ppl, global_step) when
    `hparams.eval_on_fly` is set, otherwise just `global_step`.
  """
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  if hparams.eval_on_fly:
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
      steps_per_external_eval = 2 * steps_per_eval
  else:
    steps_per_snapshot = hparams.snapshot_interval
  # Pick the model class from the attention configuration.
  if not hparams.attention:
    model_creator = nmt_model.Model
  elif hparams.attention_architecture == "standard":
    model_creator = attention_model.AttentionModel
  elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
    model_creator = gnmt_model.GNMTModel
  else:
    raise ValueError("Unknown model architecture")
  train_model = create_train_model(model_creator, hparams, scope,
                                   single_cell_fn)
  if hparams.eval_on_fly:
    eval_model = create_eval_model(model_creator, hparams, scope,
                                   single_cell_fn)
    infer_model = inference.create_infer_model(model_creator, hparams, scope,
                                               single_cell_fn)
  # Preload data for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)
  summary_name = "train_log"
  model_dir = hparams.out_dir
  # Log and output files.
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)
  avg_step_time = 0.0
  # TensorFlow model.
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement)
  train_sess = tf.Session(target=target_session, config=config_proto,
                          graph=train_model.graph)
  if hparams.eval_on_fly:
    eval_sess = tf.Session(target=target_session, config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto,
                            graph=infer_model.graph)
  with train_model.graph.as_default():
    model_helper.initialize_cnn(train_model.model, train_sess)
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")
  # Summary writer.
  summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                         train_model.graph)
  # First evaluation.
  if hparams.eval_on_fly:
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data)
  last_stats_step = global_step
  if hparams.eval_on_fly:
    last_eval_step = global_step
    last_external_eval_step = global_step
  else:
    last_snapshot_step = global_step
  # This is the training loop. Initialize the accumulators.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()
  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(
          session=train_sess), time.ctime()), log_f)
  # Initialize all of the iterators, skipping already-trained-on examples.
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(train_model.iterator.initializer,
                 feed_dict={train_model.skip_count_placeholder: skip_count})
  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      (_, step_loss, step_predict_count, step_summary, global_step,
       step_word_count, batch_size) = step_result
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      if hparams.eval_on_fly:
        run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                          summary_writer, sample_src_data, sample_tgt_data)
        dev_scores, test_scores, _ = run_external_eval(
            infer_model, infer_sess, model_dir, hparams, summary_writer)
      train_sess.run(train_model.iterator.initializer,
                     feed_dict={train_model.skip_count_placeholder: 0})
      continue
    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    # Update statistics accumulated since the last stats report.
    step_time += (time.time() - start_time)
    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)
    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      # Print statistics for the previous window.
      avg_step_time = step_time / steps_per_stats
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      utils.print_out(
          " global step %d lr %g step-time %.2fs wps %.2fK ppl %.2f %s" %
          (global_step,
           loaded_train_model.learning_rate.eval(session=train_sess),
           avg_step_time, speed, train_ppl, _get_best_results(hparams)),
          log_f)
      # NaN perplexity means training diverged; stop the loop.
      if math.isnan(train_ppl):
        break
      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0
    # Snapshot-only mode: save checkpoints on the snapshot schedule.
    if (not hparams.eval_on_fly) and (
        global_step - last_snapshot_step >= steps_per_snapshot):
      last_snapshot_step = global_step
      utils.print_out("# Cihan: Saving Snapshot, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)
      # Save checkpoint.
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "translate.ckpt"),
                                    global_step=global_step)
    if hparams.eval_on_fly and (
        global_step - last_eval_step >= steps_per_eval):
      last_eval_step = global_step
      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)
      # Save checkpoint.
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "translate.ckpt"),
                                    global_step=global_step)
      # Evaluate on dev/test.
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir,
                                            hparams, summary_writer)
    if hparams.eval_on_fly and (
        global_step - last_external_eval_step >= steps_per_external_eval):
      last_external_eval_step = global_step
      # Save checkpoint.
      loaded_train_model.saver.save(train_sess,
                                    os.path.join(out_dir, "translate.ckpt"),
                                    global_step=global_step)
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      dev_scores, test_scores, _ = run_external_eval(
          infer_model, infer_sess, model_dir,
          hparams, summary_writer)
  # Done training: save a final checkpoint.
  loaded_train_model.saver.save(train_sess,
                                os.path.join(out_dir, "translate.ckpt"),
                                global_step=global_step)
  if hparams.eval_on_fly:
    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)
    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
      best_model_dir = getattr(hparams, "best_" + metric + "_dir")
      result_summary, best_global_step, _, _, _, _ = run_full_eval(
          best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
          hparams, summary_writer, sample_src_data, sample_tgt_data)
      utils.print_out(
          "# Best %s, step %d step-time %.2f wps %.2fK, %s, %s" %
          (metric, best_global_step, avg_step_time, speed, result_summary,
           time.ctime()), log_f)
  summary_writer.close()
  if hparams.eval_on_fly:
    return dev_scores, test_scores, dev_ppl, test_ppl, global_step
  else:
    return global_step
def train(hparams):
  """Epoch-driven training loop: train, periodically eval, infer per epoch.

  Args:
    hparams: hyperparameter object; supplies the model factories' config,
      batch size, epoch count, eval cadence and output directory.
  """
  train_model = mc.create_train_model(hparams)
  eval_model = mc.create_eval_model(hparams)
  infer_model = mc.create_infer_model(hparams)
  # TensorFlow sessions, one per graph.
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  train_sess = tf.Session(graph=train_model.graph, config=config_proto)
  eval_sess = tf.Session(graph=eval_model.graph, config=config_proto)
  infer_sess = tf.Session(graph=infer_model.graph, config=config_proto)
  with train_model.graph.as_default():
    # NOTE(review): `train_model` is rebound here from the container tuple
    # to the loaded model object; later code reads `.model`, `.iterator`
    # and `.saver` off it — confirm the returned object carries all three.
    train_model, global_step = model_helper.create_or_load_model(
        train_model, hparams.out_dir, train_sess, "train")
  train_sess.run(train_model.iterator.initializer)
  step_in_epoch = 0
  current_epoch = 0
  chkpt_predicted_count, chkpt_train_loss = 0.0, 0.0
  start_time = time.time()
  while current_epoch < hparams.epoch:
    try:
      _, \
      train_loss, \
      step_predicted_count, \
      source, \
      target_input, \
      target_output, \
      logits, \
      final_context_state, \
      sample_id, \
      global_step = train_model.model.train(train_sess)
      step_in_epoch += 1
    except tf.errors.OutOfRangeError:
      # Dataset exhausted: run inference metrics and start the next epoch.
      step_in_epoch = 0
      print("epoch %s finished" % str(current_epoch))
      infer_util.run_infer(hparams, infer_sess, infer_model, current_epoch,
                           ["bleu", "rouge", "accuracy"])
      current_epoch = current_epoch + 1
      train_sess.run(train_model.iterator.initializer)
      continue
    # train_loss = step_results[0]
    # step_predicted_count = step_results[1]
    chkpt_predicted_count += step_predicted_count
    chkpt_train_loss += (train_loss * hparams.batch_size)
    # Periodic reporting, checkpointing and evaluation within the epoch.
    if step_in_epoch % hparams.steps_per_eval == 0 and step_in_epoch > 0:
      # if hparams.time_major:
      #   source = np.transpose(source)
      #   target_input = np.transpose(target_input)
      #   target_output = np.transpose(target_output)
      print("global step: ", str(global_step))
      # print (source[0])
      # print (target_input[0])
      # print (target_output[0])
      # print (np.shape(logits))
      # print (final_context_state)
      # print (sample_id)
      # user_input = input("please input")
      # print (user_input)
      train_ppl = utils.safe_exp(chkpt_train_loss / chkpt_predicted_count)
      print("epoch %d: step_in_epoch %d, train ppl %.2f, avg loss %.2f, lr %.7f, time %.2fs" %
            (current_epoch, step_in_epoch, train_ppl,
             chkpt_train_loss / (hparams.steps_per_eval * hparams.batch_size),
             train_model.model.learning_rate.eval(session=train_sess),
             time.time() - start_time))
      # Reset the reporting window.
      start_time = time.time()
      chkpt_predicted_count, chkpt_train_loss = 0.0, 0.0
      train_model.saver.save(train_sess,
                             os.path.join(hparams.out_dir, "translate.ckpt"),
                             global_step=current_epoch)
      eval_ppl, eval_loss = eval_util.run_eval(hparams, eval_sess, eval_model,
                                               current_epoch, step_in_epoch,)