Example no. 1
def train(hparams, scope=None, target_session="", compute_ppl=0):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:  # no attention: use the plain NMT model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)
    wsd_src_file = hparams.sample_prefix

    wsd_src_data = inference.load_data(wsd_src_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    '''
    run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data, avg_ckpts)
    '''
    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    end_step = global_step + 100  # run only 100 steps here; the original bound was num_train_steps
    while global_step < end_step:
        ### Run a step ###
        start_time = time.time()
        try:
            # then forward inference result to WSD, get reward
            step_result = loaded_train_model.train(train_sess)
            # forward reward to placeholder of loaded_train_model, and write a new train function where loss = loss*reward
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            # run_sample_decode(infer_model, infer_sess, model_dir, hparams,
            #                   summary_writer, sample_src_data, sample_tgt_data)

            # only for pretrain
            # run_external_eval(infer_model, infer_sess, model_dir, hparams,
            #                   summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})

            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result, hparams)
        summary_writer.add_summary(step_summary, global_step)
        if compute_ppl:
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)
        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)
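The comments inside the step loop above sketch a planned extension: forward each step's inference output to a WSD system, get a reward back, and scale the training loss by that reward. A minimal sketch of that wiring, assuming a scalar reward placeholder and helper names that are not part of this repository:

import tensorflow as tf

def add_reward_weighted_loss(train_loss):
    """Scale the existing training loss by an externally computed reward."""
    # Hypothetical placeholder, fed once per step with the WSD reward.
    reward_ph = tf.placeholder(tf.float32, shape=[], name="wsd_reward")
    # stop_gradient keeps the reward itself out of backprop; only the loss is scaled.
    weighted_loss = train_loss * tf.stop_gradient(reward_ph)
    return weighted_loss, reward_ph

# At each step, after obtaining `reward` from the WSD system:
#   train_sess.run(update_op, feed_dict={reward_ph: reward})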
Example no. 2
def train(hparams):
    """Train a seq2seq model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats

    model_creator = model.Model

    train_model = model_helper.create_train_model(model_creator, hparams)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_files=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(config=config_proto, graph=train_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    last_stats_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(
        loaded_train_model, train_model, train_sess, global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out("# Finished an epoch, step %d." % global_step)

            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step, steps_per_stats, log_f)
            print_step_info("  ", global_step, info, log_f)

            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

    # Done training
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "seq2seq.ckpt"),
        global_step=global_step)

    summary_writer.close()

    return global_step
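The loop above leans on init_stats, update_stats and process_stats, which are defined elsewhere in the file. A minimal sketch of the bookkeeping the first two are assumed to perform; the step_result layout unpacked here is an assumption, not the repo's exact tuple:

import time

def init_stats():
    """Zeroed accumulators for one reporting window."""
    return {"step_time": 0.0, "train_loss": 0.0, "predict_count": 0.0,
            "word_count": 0.0, "sequence_count": 0.0}

def update_stats(stats, start_time, step_result):
    """Accumulate one step's results; return (global_step, lr, step_summary)."""
    (_, step_loss, predict_count, step_summary, global_step,
     word_count, batch_size, _, learning_rate) = step_result  # assumed order
    stats["step_time"] += time.time() - start_time
    stats["train_loss"] += step_loss * batch_size
    stats["predict_count"] += predict_count
    stats["word_count"] += word_count
    stats["sequence_count"] += batch_size
    return global_step, learning_rate, step_summary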
Example no. 3
def train(hparams, identity, scope=None, target_session=""):
  """main loop to train the dialogue model. identity is used."""
  out_dir = hparams.out_dir
  steps_per_stats = hparams.steps_per_stats
  steps_per_internal_eval = 3 * steps_per_stats

  model_creator = diag_model.Model

  train_model = model_helper.create_train_model(model_creator, hparams, scope)

  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, identity + "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  avg_step_time = 0.0

  # load TensorFlow session and model
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)

  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)

  train_handle = train_sess.run(train_model.train_iterator.string_handle())

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # initialize summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)

  last_stats_step = global_step
  last_eval_step = global_step

  # initialize training stats.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()),
      log_f)

  # initialize iterators
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.train_iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  # main training loop
  while global_step < hparams.num_train_steps:
    start_time = time.time()
    try:  #  run a step
      step_result = loaded_train_model.train(train_sess, train_handle)
      (_, step_loss, all_summaries, step_predict_count, step_summary,
       global_step, step_word_count, batch_size, _, _, words1, words2, mask1,
       mask2) = step_result
      hparams.epoch_step += 1

    except tf.errors.OutOfRangeError:  # finished an epoch
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.train_iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    for key in all_summaries:
      utils.add_summary(summary_writer, global_step, key, all_summaries[key])

    # update statistics
    step_time += (time.time() - start_time)

    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    if global_step - last_stats_step >= steps_per_stats:
      # print statistics for the previous epoch and save the model.
      last_stats_step = global_step

      avg_step_time = step_time / steps_per_stats
      utils.add_summary(summary_writer, global_step, "step_time", avg_step_time)
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      if math.isnan(train_ppl):
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

      # save the model
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "dialogue.ckpt"),
          global_step=global_step)

      # print the dialogue if in debug mode
      if hparams.debug:
        utils.print_current_dialogue(words1, words2, mask1, mask2)

    #  write out internal evaluation
    if global_step - last_eval_step >= steps_per_internal_eval:
      last_eval_step = global_step

      utils.print_out("# Internal Evaluation. global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

  # finished training
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "dialogue.ckpt"),
      global_step=global_step)
  result_summary = ""
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()),
      log_f)
  utils.print_time("# Done training!", start_train_time)
  utils.print_out("# Start evaluating saved best models.")
  summary_writer.close()
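train_ppl above is computed with utils.safe_exp, which is not shown; it is assumed to be an overflow-guarded exponential along these lines:

import math

def safe_exp(value):
    """exp(value), returning inf instead of raising on overflow."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")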
Example no. 4
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    # Create model
    model_creator = get_model_creator(hparams)
    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data,
                  avg_ckpts)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info, get_best_results(hparams),
                            log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            add_info_summaries(summary_writer, global_step, info)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = (run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts))
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step,
                            info, result_summary, log_f)
            summary_writer.close()

    return final_eval_metrics, global_step
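Example no. 4 calls get_model_creator(hparams) instead of selecting the model class inline. Judging from the selection logic spelled out in Examples 1 and 7, it is assumed to look roughly like this (nmt_model, attention_model and gnmt_model are the modules those examples import):

def get_model_creator(hparams):
    """Pick the model class from the attention-related hparams."""
    if not hparams.attention:
        return nmt_model.Model
    if (hparams.encoder_type == "gnmt"
            or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
        return gnmt_model.GNMTModel
    if hparams.attention_architecture == "standard":
        return attention_model.AttentionModel
    raise ValueError("Unknown attention architecture %s" %
                     hparams.attention_architecture)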
Example no. 5
def train(hps, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hps.log_device_placement
    out_dir = hps.out_dir
    num_train_steps = hps.num_train_steps
    steps_per_stats = hps.steps_per_stats
    steps_per_external_eval = hps.steps_per_external_eval
    steps_per_eval = 100 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hps.attention_architecture == "baseline":
        model_creator = AttentionModel
    else:
        model_creator = AttentionHistoryModel

    train_model = model_helper.create_train_model(model_creator, hps, scope)
    eval_model = model_helper.create_eval_model(model_creator, hps, scope)
    infer_model = model_helper.create_infer_model(model_creator, hps, scope)

    # Preload data for sample decoding.

    article_filenames = []
    abstract_filenames = []
    art_dir = hps.data_dir + '/article'
    abs_dir = hps.data_dir + '/abstract'
    for file in os.listdir(art_dir):
        if file.startswith(hps.dev_prefix):
            article_filenames.append(art_dir + "/" + file)
    for file in os.listdir(abs_dir):
        if file.startswith(hps.dev_prefix):
            abstract_filenames.append(abs_dir + "/" + file)
    # Randomly pick one dev article/abstract pair for sample decoding during training.
    decode_id = random.randint(0, len(article_filenames) - 1)
    single_article_file = article_filenames[decode_id]
    single_abstract_file = abstract_filenames[decode_id]

    dev_src_file = single_article_file
    dev_tgt_file = single_abstract_file
    sample_src_data = inference_base_model.load_data(dev_src_file)
    sample_tgt_data = inference_base_model.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hps.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    # run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess, hps,
    #     summary_writer,sample_src_data,sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hps.batch_size * hps.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})
    epoch_step = 0
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = update_stats(stats, summary_writer, start_time,
                                   step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = check_stats(stats, global_step, steps_per_stats, hps,
                                      log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hps,
                                                  summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "summarized.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hps,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hps.metrics:
        best_model_dir = getattr(hps, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hps, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)
        summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
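Several of these examples re-initialize the training iterator with skip_count = batch_size * epoch_step so that a restarted job resumes mid-epoch. A minimal sketch of the input-pipeline side of that contract; the file name and variable names here are placeholders, not the repo's code:

import tensorflow as tf

skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
dataset = tf.data.TextLineDataset("train.src")   # placeholder path
dataset = dataset.skip(skip_count_placeholder)   # drop already-consumed lines
iterator = dataset.make_initializable_iterator()
# train_sess.run(iterator.initializer,
#                feed_dict={skip_count_placeholder: batch_size * epoch_step})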
Example no. 6
def train(flags):
    """Train the policy gradient model. """
    
    out_dir = flags.out_dir
    num_train_steps = flags.num_train_steps
    steps_per_infer = flags.steps_per_infer
    
    # Create model for train, infer mode
    model_creator = get_model_creator(flags)
    train_model = model_helper.create_train_model(flags, model_creator)
    infer_model = model_helper.create_infer_model(flags, model_creator)

    # TODO: configure for distributed training and multi-GPU setups.
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    
    # Session for train, infer
    train_sess = tf.Session(
        config=config_proto, graph=train_model.graph)
    infer_sess = tf.Session(
        config=config_proto, graph=infer_model.graph)
    
    # Load the train model from a checkpoint if one exists in out_dir;
    # otherwise, initialize its variables.
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            out_dir, train_model.model, train_sess) 
    
    # Summary
    train_summary = "train_log"
    infer_summary = "infer_log"

    # Summary writer for train, infer
    train_summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, train_summary), train_model.graph)
    infer_summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, infer_summary))
    
    # First evaluation
    run_infer(infer_model, out_dir, infer_sess)

    # Initialize step var
    last_infer_steps = global_step
    
    # Training loop
    while global_step < num_train_steps:
        output_tuple = loaded_train_model.train(train_sess)
        global_step = output_tuple.global_step
        train_summary = output_tuple.train_summary    
    
        # Update train summary
        train_summary_writer.add_summary(train_summary, global_step)
        print('current global_step: {}'.format(global_step))

        # Evaluate the model every steps_per_infer steps
        if global_step - last_infer_steps >= steps_per_infer:
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, 
                os.path.join(out_dir, "rl.ckpt"), global_step)
            
            last_infer_steps = global_step
            output_tuple = run_infer(infer_model, out_dir, infer_sess)  
            infer_summary = output_tuple.infer_summary

            # Update infer summary
            infer_summary_writer.add_summary(infer_summary, global_step)
    
    # Done training
    loaded_train_model.saver.save(train_sess, 
        os.path.join(out_dir, "rl.ckpt"), global_step)
    print('Train done')
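loaded_train_model.train() and run_infer() in this example return objects with named fields (output_tuple.global_step, output_tuple.train_summary, output_tuple.infer_summary). A minimal assumption of how such tuples could be defined:

import collections

TrainOutputTuple = collections.namedtuple(
    "TrainOutputTuple", ("train_summary", "global_step"))
InferOutputTuple = collections.namedtuple(
    "InferOutputTuple", ("infer_summary",))

# The model's train() / the run_infer() helper would then end with, e.g.:
#   return TrainOutputTuple(train_summary=summary, global_step=step)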
Example no. 7
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="w")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)

    if hparams.curriculum == 'none':
        train_sess.run(
            train_model.iterator.initializer,
            feed_dict={train_model.skip_count_placeholder: skip_count})
    else:
        if hparams.curriculum == 'predictive_gain':
            exp3s = Exp3S(hparams.num_curriculum_buckets, 0.001, 0, 0.05)
        elif hparams.curriculum == 'look_back_and_forward':
            curriculum_point = 0

        handle = train_model.iterator.handle
        for i in range(hparams.num_curriculum_buckets):
            train_sess.run(
                train_model.iterator.initializer[i].initializer,
                feed_dict={train_model.skip_count_placeholder: skip_count})

        iterator_handles = [
            train_sess.run(
                train_model.iterator.initializer[i].string_handle(),
                feed_dict={train_model.skip_count_placeholder: skip_count})
            for i in range(hparams.num_curriculum_buckets)
        ]

    utils.print_out("Starting training")

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            if hparams.curriculum != 'none':
                if hparams.curriculum == 'predictive_gain':
                    lesson = exp3s.draw_task()
                elif hparams.curriculum == 'look_back_and_forward':
                    if curriculum_point == hparams.num_curriculum_buckets:
                        lesson = np.random.randint(
                            low=0, high=hparams.num_curriculum_buckets)
                    else:
                        lesson = curriculum_point if np.random.random_sample(
                        ) < 0.8 else np.random.randint(
                            low=0, high=hparams.num_curriculum_buckets)

                step_result = loaded_train_model.train(
                    hparams,
                    train_sess,
                    handle=handle,
                    iterator_handle=iterator_handles[lesson],
                    use_fed_source_placeholder=loaded_train_model.
                    use_fed_source,
                    fed_source_placeholder=loaded_train_model.fed_source)

                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size, source) = step_result

                if hparams.curriculum == 'predictive_gain':
                    new_loss = train_sess.run(
                        [loaded_train_model.train_loss],
                        feed_dict={
                            handle: iterator_handles[lesson],
                            loaded_train_model.use_fed_source: True,
                            loaded_train_model.fed_source: source
                        })

                    # new_loss = loaded_train_model.train_loss.eval(
                    #   session=train_sess,
                    #   feed_dict={
                    #     handle: iterator_handles[lesson],
                    #     loaded_train_model.use_fed_source: True,
                    #     loaded_train_model.fed_source: source
                    #   })

                    # utils.print_out("lesson: %s, step loss: %s, new_loss: %s" % (lesson, step_loss, new_loss))
                    # utils.print_out("exp3s dist: %s" % (exp3s.pi, ))

                    curriculum_point_a = lesson * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (
                        lesson + 1) * (hparams.src_max_len //
                                       hparams.num_curriculum_buckets) + 1

                    v = step_loss - new_loss
                    exp3s.update_w(
                        v,
                        float(curriculum_point_a + curriculum_point_b) / 2.0)
                elif hparams.curriculum == 'look_back_and_forward':
                    utils.print_out("step loss: %s, lesson: %s" %
                                    (step_loss, lesson))
                    curriculum_point_a = curriculum_point * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (curriculum_point + 1) * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1

                    if step_loss < (hparams.curriculum_progress_loss *
                                    (float(curriculum_point_a +
                                           curriculum_point_b) / 2.0)):
                        curriculum_point += 1
            else:
                step_result = loaded_train_model.train(hparams, train_sess)
                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            # utils.print_out(
            #     "# Finished an epoch, step %d. Perform external evaluation" %
            #     global_step)
            # run_sample_decode(infer_model, infer_sess,
            #                   model_dir, hparams, summary_writer, sample_src_data,
            #                   sample_tgt_data)
            # dev_scores, test_scores, _ = run_external_eval(
            #     infer_model, infer_sess, model_dir,
            #     hparams, summary_writer)
            if hparams.curriculum == 'none':
                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            else:
                train_sess.run(
                    train_model.iterator.initializer[lesson].initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            if hparams.curriculum == 'predictive_gain':
                utils.print_out("lesson: %s, step loss: %s, new_loss: %s" %
                                (lesson, step_loss, new_loss))
                utils.print_out("exp3s dist: %s" % (exp3s.pi, ))

            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss /
                                       checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)

            if math.isnan(train_ppl):
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)

            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

        # if global_step - last_external_eval_step >= steps_per_external_eval:
        #   last_external_eval_step = global_step

        #   # Save checkpoint
        #   loaded_train_model.saver.save(
        #       train_sess,
        #       os.path.join(out_dir, "translate.ckpt"),
        #       global_step=global_step)
        #   run_sample_decode(infer_model, infer_sess,
        #                     model_dir, hparams, summary_writer, sample_src_data,
        #                     sample_tgt_data)
        #   dev_scores, test_scores, _ = run_external_eval(
        #       infer_model, infer_sess, model_dir,
        #       hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)

    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)

    summary_writer.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
def main(_):
    default_hparams = nmt.create_hparams(FLAGS)
    ## Train / Decode
    out_dir = FLAGS.out_dir
    if not tf.gfile.Exists(out_dir): tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    hparams = nmt.create_or_load_hparams(out_dir,
                                         default_hparams,
                                         FLAGS.hparams_path,
                                         save_hparams=False)

    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model =\
        model_helper.create_train_model(model_creator, hparams, scope=None)

    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=1,
        num_inter_threads=36)

    def run(train_sess, num_workers, worker_id, num_replicas_per_worker):

        # Random
        random_seed = FLAGS.random_seed
        if random_seed is not None and random_seed > 0:
            utils.print_out("# Set random seed to %d" % random_seed)
            random.seed(random_seed + worker_id)
            np.random.seed(random_seed + worker_id)

        # Log and output files
        log_file = os.path.join(out_dir, "log_%d" % time.time())
        log_f = tf.gfile.GFile(log_file, mode="a")
        utils.print_out("# log_file=%s" % log_file, log_f)

        global_step = train_sess.run(train_model.model.global_step)[0]
        last_stats_step = global_step

        # This is the training loop.
        stats, info, start_train_time = before_train(train_model, train_sess,
                                                     global_step, hparams,
                                                     log_f,
                                                     num_replicas_per_worker)

        # Steps per epoch as an integer, so the modulo check below behaves the
        # same under Python 2 and Python 3.
        epoch_steps = FLAGS.epoch_size // (
            FLAGS.batch_size * num_workers * num_replicas_per_worker)

        for i in range(FLAGS.max_steps):
            ### Run a step ###
            start_time = time.time()
            if hparams.epoch_step != 0 and hparams.epoch_step % epoch_steps == 0:
                hparams.epoch_step = 0
                skip_count = train_model.skip_count_placeholder
                feed_dict = {}
                feed_dict[skip_count] = [
                    0 for i in range(num_replicas_per_worker)
                ]
                init = train_model.iterator.initializer
                train_sess.run(init, feed_dict=feed_dict)

            if worker_id == 0:
                results = train_sess.run([
                    train_model.model.update, train_model.model.train_loss,
                    train_model.model.predict_count,
                    train_model.model.train_summary,
                    train_model.model.global_step,
                    train_model.model.word_count, train_model.model.batch_size,
                    train_model.model.grad_norm,
                    train_model.model.learning_rate
                ])
                step_result = [r[0] for r in results]

            else:
                global_step, _ = train_sess.run(
                    [train_model.model.global_step, train_model.model.update])
            hparams.epoch_step += 1

            if worker_id == 0:
                # Process step_result, accumulate stats, and write summary
                global_step, info["learning_rate"], step_summary = \
                    train.update_stats(stats, start_time, step_result)

                # Once in a while, we print statistics.
                if global_step - last_stats_step >= steps_per_stats:
                    last_stats_step = global_step
                    is_overflow = train.process_stats(stats, info, global_step,
                                                      steps_per_stats, log_f)
                    train.print_step_info("  ", global_step, info,
                                          train._get_best_results(hparams),
                                          log_f)
                    if is_overflow:
                        break

                    # Reset statistics
                    stats = train.init_stats()

    sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(train_model.graph,
                              FLAGS.resource_info_file,
                              sync=FLAGS.sync,
                              parallax_config=parallax_config.build_config())
    run(sess, num_workers, worker_id, num_replicas_per_worker)
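before_train is called in Examples 1, 2, 4 and 8 but not shown (Example 8 passes a slightly different argument list). Judging from the inlined equivalents in Examples 3, 5 and 7, it is assumed to do roughly the following, using the module-level time and utils imports those examples already rely on:

def before_train(loaded_train_model, train_model, train_sess, global_step,
                 hparams, log_f):
    """Set up stats/info, log the start line, and init the train iterator."""
    stats = init_stats()
    info = {"train_ppl": 0.0, "speed": 0.0, "avg_step_time": 0.0,
            "avg_grad_norm": 0.0,
            "learning_rate": loaded_train_model.learning_rate.eval(
                session=train_sess)}
    start_train_time = time.time()
    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, info["learning_rate"], time.ctime()), log_f)
    # Resume mid-epoch: skip the examples already consumed in this epoch.
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})
    return stats, info, start_train_time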
Example no. 9
def train(hparams, scope=None):
    """Train a GNMT translation model."""
    model_dir = hparams.out_dir
    avg_ckpts = hparams.avg_ckpts
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval
    summary_name = "summary"

    model_creator = gnmt_model.GNMTModel
    train_model = model_helper.create_train_model(model_creator, hparams)
    eval_model = model_helper.create_eval_model(model_creator, hparams)
    infer_model = model_helper.create_infer_model(model_creator, hparams)

    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    train_sess = tf.Session(graph=train_model.graph, config=config_proto)
    eval_sess = tf.Session(graph=eval_model.graph, config=config_proto)
    infer_sess = tf.Session(graph=infer_model.graph, config=config_proto)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(model_dir, summary_name), train_model.graph)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = utils.load_data(dev_src_file)
    sample_tgt_data = utils.load_data(dev_tgt_file)

    # First evaluation
    result_summary, _, _ = run_full_eval(model_dir, infer_model, infer_sess,
                                         eval_model, eval_sess, hparams,
                                         summary_writer, sample_src_data,
                                         sample_tgt_data, avg_ckpts)
    utils.log('First evaluation: {}'.format(result_summary))

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    info = {
        "train_ppl": 0.0,
        "speed": 0.0,
        "avg_step_time": 0.0,
        "avg_grad_norm": 0.0,
        "learning_rate":
        loaded_train_model.learning_rate.eval(session=train_sess)
    }
    utils.log("Start step %d, lr %g" % (global_step, info["learning_rate"]))

    # Initialize all of the iterators
    train_sess.run(train_model.iterator.initializer)

    epoch = 1

    while True:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            utils.log(
                "Finished epoch %d, step %d. Perform external evaluation" %
                (epoch, global_step))
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)
            train_sess.run(train_model.iterator.initializer)

            if epoch < hparams.epochs:
                epoch += 1
                continue
            else:
                break

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats)
            print_step_info("  ", global_step, info,
                            "BLEU %.2f" % (hparams.best_bleu, ))
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.log("Save eval, global step %d" % (global_step, ))
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(model_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(model_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(model_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)
    print_step_info("Final, ", global_step, info, result_summary)
    utils.log("Done training!")

    summary_writer.close()

    utils.log("Start evaluating saved best models.")
    best_model_dir = hparams.best_bleu_dir
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
        hparams, summary_writer, sample_src_data, sample_tgt_data)
    print_step_info("Best BLEU, ", best_global_step, info, result_summary)
    summary_writer.close()

    if avg_ckpts:
        best_model_dir = hparams.avg_best_bleu_dir
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("Averaged Best BLEU, ", best_global_step, info,
                        result_summary)
        summary_writer.close()

    return final_eval_metrics, global_step
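
The while-True / OutOfRangeError structure above is the standard TF 1.x way to detect the end of an epoch with an initializable iterator: when the dataset is exhausted, sess.run raises tf.errors.OutOfRangeError, the loop runs its end-of-epoch evaluation, re-initializes the iterator, and continues. A minimal, self-contained sketch of that mechanism on a toy dataset (independent of the NMT models used above):

import tensorflow as tf  # TF 1.x style API, as in the examples above

dataset = tf.data.Dataset.range(10).batch(4)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(2):
        sess.run(iterator.initializer)  # rewind the dataset for this epoch
        while True:
            try:
                print("epoch %d:" % epoch, sess.run(next_batch))
            except tf.errors.OutOfRangeError:
                # Dataset exhausted -> end of epoch.
                break
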
Example n. 10
def train(hparams):
    num_epochs = hparams.num_epochs
    num_ckpt_epochs = hparams.num_ckpt_epochs
    summary_name = "train_log"
    out_dir = hparams.out_dir
    model_dir = out_dir
    log_device_placement = hparams.log_device_placement

    input_emb_weights = np.loadtxt(
        hparams.input_emb_file,
        delimiter=' ') if hparams.input_emb_file else None
    if hparams.model_architecture == "rnn-model":
        model_creator = model.RNN
    else:
        raise ValueError(
            "Unknown model architecture. Only rnn-model is supported so far.")
    # create 2 models in 2 graphs for train and evaluation, with 2 sessions sharing the same variables.
    train_model = model_helper.create_train_model(
        model_creator,
        hparams,
        hparams.train_input_path,
        hparams.train_target_path,
        mode=tf.contrib.learn.ModeKeys.TRAIN)
    eval_model = model_helper.create_eval_model(model_creator, hparams,
                                                tf.contrib.learn.ModeKeys.EVAL)

    # some configuration for GPU usage and device-placement logging
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement, allow_soft_placement=True)
    # create two separate sessions for train/eval
    train_sess = tf.Session(config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(config=config_proto, graph=eval_model.graph)

    # create a new train model by initializing all variables of the train graph in the train_sess.
    # or, using the latest checkpoint in the model_dir, load all variables of the train graph in the train_sess.
    # Note that at this point, the eval graph variables are not initialized.
    with train_model.graph.as_default():
        loaded_train_model = model_helper.create_or_load_model(
            train_model.model, train_sess, "train", model_dir,
            input_emb_weights)
    # create a log file with name summary_name in out_dir. The file is written asynchronously during the training process.
    # We also passed the train graph in order to be able to display it in Tensorboard
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    #run first evaluation before starting training
    dev_loss = run_evaluation(eval_model, eval_sess, model_dir,
                              hparams.val_input_path, hparams.val_target_path,
                              input_emb_weights, summary_writer)
    train_loss = run_evaluation(eval_model, eval_sess, model_dir,
                                hparams.train_input_path,
                                hparams.train_target_path, input_emb_weights,
                                summary_writer)
    print("Dev loss before training: %.3f" % dev_loss)
    print("Train loss before training: %.3f" % train_loss)
    # Start training
    start_train_time = time.time()
    epoch_time = 0.0
    batch_loss, epoch_loss = 0.0, 0.0
    batch_count = 0.0

    #initialize train iterator in train_sess
    train_sess.run(train_model.iterator.initializer)
    #keep lists of train/val losses for all epochs
    train_losses = []
    dev_losses = []
    #train the model for num_epochs. One epoch means a pass through the whole train dataset, i.e., through all the batches.
    for epoch in range(num_epochs):
        #go through all batches for the current epoch
        while True:
            start_batch_time = time.time()
            try:
                # this call will run operations of train graph in train_sess
                step_result = loaded_train_model.train(train_sess)
                (_, batch_loss, batch_summary, global_step, learning_rate,
                 batch_size, inputs, targets) = step_result
                epoch_time += (time.time() - start_batch_time)
                epoch_loss += batch_loss
                batch_count += 1
            except tf.errors.OutOfRangeError:
                #when the iterator of the train batches reaches the end, break the loop
                #and reinitialize the iterator to start from the beginning of the train data.
                train_sess.run(train_model.iterator.initializer)
                break
        # average epoch loss and epoch time over batches
        epoch_loss /= batch_count
        epoch_time /= batch_count
        batch_count = 0.0
        #print results if the current epoch is a print results epoch
        if (epoch + 1) % num_ckpt_epochs == 0:
            print("Saving checkpoint...")
            model_helper.add_summary(summary_writer, "train_loss", epoch_loss)
            # save checkpoint. We save the values of the variables of the train graph.
            # train_sess is the session in which the train graph was launched.
            # global_step parameter is optional and is appended to the name of the checkpoint.
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir, "rnn.ckpt"),
                                          global_step=epoch)

            print("Results: ")
            dev_loss = run_evaluation(eval_model, eval_sess, model_dir,
                                      hparams.val_input_path,
                                      hparams.val_target_path,
                                      input_emb_weights, summary_writer)
            # tr_loss = run_evaluation(eval_model, eval_sess, model_dir, hparams.train_input_path, hparams.train_target_path, input_emb_weights, summary_writer)
            # print("check %.3f:"%tr_loss)
            print(" epoch %d lr %g "
                  "train_loss %.3f, dev_loss %.3f" %
                  (epoch,
                   loaded_train_model.learning_rate.eval(session=train_sess),
                   epoch_loss, dev_loss))
            train_losses.append(epoch_loss)
            dev_losses.append(dev_loss)
        # Reset epoch accumulators so losses/times do not carry over to the next epoch.
        epoch_loss = 0.0
        epoch_time = 0.0

    # save final model
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "rnn.ckpt"),
                                  global_step=num_epochs)
    print("Done training in %.2fK" % (time.time() - start_train_time))
    min_dev_loss = np.min(dev_losses)
    min_dev_idx = np.argmin(dev_losses)
    print("Min val loss: %.3f at epoch %d" % (min_dev_loss, min_dev_idx))
    summary_writer.close()
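
model_helper.create_or_load_model in these examples wraps the usual TF 1.x checkpoint logic: look for the newest checkpoint in the model directory, restore it if present, otherwise initialize variables from scratch. The following is a minimal sketch of that logic under those assumptions, not the exact model_helper implementation:

import tensorflow as tf  # TF 1.x style API


def create_or_load(sess, saver, model_dir, global_step_tensor):
    # Restore the newest checkpoint if one exists; otherwise start fresh.
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        saver.restore(sess, latest_ckpt)
    else:
        sess.run(tf.global_variables_initializer())
    return sess.run(global_step_tensor)


# Saving later mirrors the calls in the training loops above, e.g.:
# saver.save(sess, os.path.join(model_dir, "translate.ckpt"), global_step=global_step)
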
Example n. 11
def train(hparams, scope=None, target_session=""):
  """Train a translation model."""
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  steps_per_eval = 10 * steps_per_stats
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  if not hparams.attention:
    model_creator = nmt_model.Model
  elif hparams.attention_architecture == "standard":
    model_creator = attention_model.AttentionModel
  elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
    model_creator = gnmt_model.GNMTModel
  else:
    raise ValueError("Unknown model architecture")

  train_model = model_helper.create_train_model(model_creator, hparams, scope)
  eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
  infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

  # Preload data for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  avg_step_time = 0.0

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement)

  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)
  infer_sess = tf.Session(
      target=target_session, config=config_proto, graph=infer_model.graph)

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  # First evaluation
  run_full_eval(
      model_dir, infer_model, infer_sess,
      eval_model, eval_sess, hparams,
      summary_writer, sample_src_data,
      sample_tgt_data)

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()),
      log_f)

  # Initialize all of the iterators
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      (_, step_loss, step_predict_count, step_summary, global_step,
       step_word_count, batch_size) = step_result
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset.  Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      run_sample_decode(infer_model, infer_sess,
                        model_dir, hparams, summary_writer, sample_src_data,
                        sample_tgt_data)
      dev_scores, test_scores, _ = run_external_eval(
          infer_model, infer_sess, model_dir,
          hparams, summary_writer)
      train_sess.run(
          train_model.iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)

    # update statistics
    step_time += (time.time() - start_time)

    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step

      # Print statistics for the previous epoch.
      avg_step_time = step_time / steps_per_stats
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      utils.print_out(
          "  global step %d lr %g "
          "step-time %.2fs wps %.2fK ppl %.2f %s" %
          (global_step,
           loaded_train_model.learning_rate.eval(session=train_sess),
           avg_step_time, speed, train_ppl, _get_best_results(hparams)),
          log_f)
      if math.isnan(train_ppl):
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

    if global_step - last_eval_step >= steps_per_eval:
      last_eval_step = global_step

      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

      # Save checkpoint
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)

      # Evaluate on dev/test
      run_sample_decode(infer_model, infer_sess,
                        model_dir, hparams, summary_writer, sample_src_data,
                        sample_tgt_data)
      dev_ppl, test_ppl = run_internal_eval(
          eval_model, eval_sess, model_dir, hparams, summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step

      # Save checkpoint
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)
      run_sample_decode(infer_model, infer_sess,
                        model_dir, hparams, summary_writer, sample_src_data,
                        sample_tgt_data)
      dev_scores, test_scores, _ = run_external_eval(
          infer_model, infer_sess, model_dir,
          hparams, summary_writer)

  # Done training
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "translate.ckpt"),
      global_step=global_step)

  result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
      model_dir, infer_model, infer_sess,
      eval_model, eval_sess, hparams,
      summary_writer, sample_src_data,
      sample_tgt_data)
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()),
      log_f)
  utils.print_time("# Done training!", start_train_time)

  summary_writer.close()

  utils.print_out("# Start evaluating saved best models.")
  for metric in hparams.metrics:
    best_model_dir = getattr(hparams, "best_" + metric + "_dir")
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _, _, _, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out("# Best %s, step %d "
                    "step-time %.2f wps %.2fK, %s, %s" %
                    (metric, best_global_step, avg_step_time, speed,
                     result_summary, time.ctime()), log_f)
    summary_writer.close()

  return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
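
The train_ppl reported above is a running per-token perplexity: the sum of (step_loss * batch_size) over the window divided by the number of predicted tokens, passed through an overflow-safe exp. A small worked sketch with made-up numbers (utils.safe_exp in the original code may differ in detail):

import math


def safe_exp(value):
    # exp() that yields inf instead of raising OverflowError on huge losses.
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")


# Made-up step results: (mean per-sentence loss, batch_size, predicted tokens).
steps = [(108.0, 128, 3100), (105.0, 128, 2950), (102.0, 128, 3010)]
checkpoint_loss = sum(loss * bs for loss, bs, _ in steps)
checkpoint_predict_count = sum(pc for _, _, pc in steps)
train_ppl = safe_exp(checkpoint_loss / checkpoint_predict_count)
print("train ppl %.2f" % train_ppl)  # exp(40320 / 9060), roughly 85.7
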
Example n. 12
def train(hparams, ckpt_dir, scope=None, target_session="", alternative=False):
    out_dir = os.path.join(hparams.base_dir, "train")
    if not misc.check_file_existence(out_dir):
        tf.gfile.MakeDirs(out_dir)
    tf.logging.info("All train relevant results will be put in %s" % out_dir)
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    ckpt_path = os.path.join(ckpt_dir, "model.ckpt")

    # Create model
    model_creator = get_model_creator(hparams.model_type)
    train_model = helper.create_train_model(model_creator, hparams, scope)
    config_proto = misc.get_config_proto(
        log_device_placement=False,
        allow_soft_placement=True,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    tf.logging.info("Create model successfully")
    with train_model.graph.as_default():
        loaded_train_model, global_step = helper.create_or_load_model(
            train_model.model, ckpt_dir, train_sess, "train")
    # Summary writer
    summary_name = "train_summary"
    summary_path = os.path.join(out_dir, summary_name)
    if not tf.gfile.Exists(summary_path):
        tf.gfile.MakeDirs(summary_path)
    summary_writer = tf.summary.FileWriter(summary_path, train_model.graph)
    last_stats_step = global_step
    # Training iteration
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step)
    tf.logging.info("Ready to train")
    epoch_step = 0
    while global_step < num_train_steps:
        start_time = time.time()
        try:
            tf.logging.info("Start train epoch:%d" % epoch_step)
            step_result = loaded_train_model.train(train_sess)
            epoch_step += 1
        except tf.errors.OutOfRangeError:
            tf.logging.info("Saving epoch step %d model into checkpoint" %
                            epoch_step)
            loaded_train_model.saver.save(train_sess,
                                          ckpt_path,
                                          global_step=global_step)
            # Training while evaluating alternately
            if alternative:
                eval.evaluate(hparams,
                              scope=scope,
                              target_session=target_session,
                              global_step_=global_step,
                              ckpt_path=ckpt_path,
                              alternative=alternative)
            # Finished going through the training dataset.  Go to next epoch.
            epoch_step = 0
            tf.logging.info("# Finished an epoch, step %d." % global_step)
            train_sess.run(train_model.data_wrapper.initializer)
            continue

        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)
        # Once in a while, print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            tf.logging.info(
                "In global step:%d, time to print train statistics" %
                global_step)
            last_stats_step = global_step
            # Update info
            process_stats(stats, info, steps_per_stats, hparams.img_batch_size)
            print_step_info(global_step, info, stats)
            # Reset statistic
            stats = init_stats()

    # Done training
    tf.logging.info("Finish training, saving model into checkpoint")
    loaded_train_model.saver.save(train_sess,
                                  ckpt_path,
                                  global_step=global_step)
    summary_writer.close()
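
The utils.add_summary / model_helper.add_summary helpers used throughout these examples typically just build a tf.Summary protobuf by hand, so a scalar computed outside the graph (train_ppl, epoch loss, BLEU) still shows up in TensorBoard. A minimal sketch of such a helper (the originals may take their arguments in a different order):

import tensorflow as tf  # TF 1.x style API


def add_summary(summary_writer, global_step, tag, value):
    # Write one scalar value under `tag` at `global_step`.
    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
    summary_writer.add_summary(summary, global_step)


# Usage sketch:
# writer = tf.summary.FileWriter("/tmp/train_summary")
# add_summary(writer, global_step=100, tag="train_ppl", value=12.3)
# writer.close()
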
Example n. 13
def train(hparams):
    """Train a sequence tagging model."""
    num_epochs = hparams.num_epochs
    num_ckpt_epochs = hparams.num_ckpt_epochs
    summary_name = "train_log"
    out_dir = hparams.out_dir
    model_dir = out_dir
    log_device_placement = hparams.log_device_placement

    # Load external embedding vectors if a file is given as input. You don't care about external embeddings now.
    input_emb_weights = np.loadtxt(
        hparams.input_emb_file,
        delimiter=' ') if hparams.input_emb_file else None
    if hparams.model_architecture == "simple_rnn":
        model_creator = model.RNN
    else:
        raise ValueError(
            "Unknown model architecture. Only simple_rnn is supported so far.")
    # create 2 models in 2 separate graphs for train and evaluation.
    train_model = model_helper.create_train_model(
        model_creator,
        hparams,
        hparams.train_input_path,
        hparams.train_target_path,
        mode=tf.contrib.learn.ModeKeys.TRAIN)
    eval_model = model_helper.create_eval_model(model_creator, hparams,
                                                tf.contrib.learn.ModeKeys.EVAL)

    # some configuration for GPU usage and device-placement logging
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement, allow_soft_placement=True)
    # create two separate sessions for train/evaluation.
    train_sess = tf.Session(config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(config=config_proto, graph=eval_model.graph)

    # create a new train model by initializing all variables of the train graph in the train_sess.
    # or, using the latest checkpoint in the model_dir, load all variables of the train graph in the train_sess.
    # Note that at this point, the eval graph variables are not initialized.
    with train_model.graph.as_default():
        loaded_train_model = model_helper.create_or_load_model(
            train_model.model, train_sess, "train", model_dir,
            input_emb_weights)
    # create a log file with name summary_name in the out_dir. The file is written asynchronously during the training process.
    # We also passed the train graph in order to be able to display it in Tensorboard.
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # run first evaluation before starting training.
    val_loss, val_acc = run_evaluation(eval_model, eval_sess, model_dir,
                                       hparams.val_input_path,
                                       hparams.val_target_path,
                                       input_emb_weights, summary_writer)
    train_loss, train_acc = run_evaluation(eval_model, eval_sess, model_dir,
                                           hparams.train_input_path,
                                           hparams.train_target_path,
                                           input_emb_weights, summary_writer)
    print("Before training: Val loss %.3f, Val accuracy %.3f." %
          (val_loss, val_acc))
    print("Before training: Train loss %.3f Train acc %.3f" %
          (train_loss, train_acc))
    # Start training
    start_train_time = time.time()
    avg_batch_time = 0.0
    batch_loss, epoch_loss, epoch_accuracy = 0.0, 0.0, 0.0
    batch_count = 0.0

    # initialize train iterator in train_sess
    train_sess.run(train_model.iterator.initializer)
    # keep lists of train/val losses for all epochs.
    train_losses = []
    dev_losses = []

    # vars to compute a timeline of operations. The timeline is useful to see how much time each op in the TF graph takes.
    # You don't care about this.
    options = None
    run_metadata = None
    if hparams.timeline:
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
    # train the model for num_epochs. One epoch means a pass through the whole train dataset, i.e., through all the batches.
    step = 0
    for epoch in range(num_epochs):
        # go through all batches for the current epoch.
        while True:
            start_batch_time = time.time()
            try:
                # You don't care about the timeline now.
                if hparams.timeline and step % 10 == 0:
                    # this call will run operations of train graph in train_sess.
                    step_result = loaded_train_model.train(
                        train_sess, options=options, run_metadata=run_metadata)
                    summary_writer.add_run_metadata(run_metadata,
                                                    'step%d' % step)
                    fetched_timeline = timeline.Timeline(
                        run_metadata.step_stats)
                    chrome_trace = fetched_timeline.generate_chrome_trace_format(
                    )
                    # Make sure the timelines/ sub-directory exists before writing the trace.
                    timeline_dir = os.path.join(out_dir, 'timelines')
                    if not tf.gfile.Exists(timeline_dir):
                        tf.gfile.MakeDirs(timeline_dir)
                    with open(
                            os.path.join(
                                timeline_dir,
                                'timeline_02_step_%d.json' % step),
                            'w') as f:
                        f.write(chrome_trace)
                else:
                    # this call will run operations of train graph in train_sess.
                    step_result = loaded_train_model.train(train_sess,
                                                           options=None,
                                                           run_metadata=None)

                (_, batch_loss, batch_summary, global_step, learning_rate,
                 batch_size, batch_accuracy) = step_result
                avg_batch_time += (time.time() - start_batch_time)
                epoch_loss += batch_loss
                epoch_accuracy += batch_accuracy
                batch_count += 1
                step += 1
            except tf.errors.OutOfRangeError:
                # We went through all train batches and so, the iterator over the train batches reached the end.
                # We break the while loop and reinitialize the iterator to start from the beginning of the train data.
                train_sess.run(train_model.iterator.initializer)
                break
        # average epoch loss and epoch time over batches.
        epoch_loss /= batch_count
        avg_batch_time /= batch_count
        epoch_accuracy /= batch_count
        print("Number of batches: %d" % batch_count)
        # print results if the current epoch is a print results epoch
        if (epoch + 1) % num_ckpt_epochs == 0:
            print("Saving checkpoint...")
            model_helper.add_summary(summary_writer, "train_loss", epoch_loss)
            model_helper.add_summary(summary_writer, "train_accuracy",
                                     epoch_accuracy)
            # save checkpoint. We save the values of the variables of the train graph.
            # train_sess is the session in which the train graph was launched.
            # global_step parameter is optional and is appended to the name of the checkpoint.
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir, "rnn.ckpt"),
                                          global_step=epoch)

            print("Results: ")
            val_loss, val_accuracy = run_evaluation(
                eval_model, eval_sess, model_dir, hparams.val_input_path,
                hparams.val_target_path, input_emb_weights, summary_writer)
            print(
                " epoch %d lr %g "
                "train_loss %.3f, val_loss %.3f, train_accuracy %.3f, val accuracy %.3f, avg_batch_time %f"
                % (epoch,
                   loaded_train_model.learning_rate.eval(session=train_sess),
                   epoch_loss, val_loss, epoch_accuracy, val_accuracy,
                   avg_batch_time))
            train_losses.append(epoch_loss)
            dev_losses.append(val_loss)
        batch_count = 0.0
        avg_batch_time = 0.0
        epoch_loss = 0.0
        epoch_accuracy = 0.0

    # save final model
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "rnn.ckpt"),
                                  global_step=num_epochs)
    print("Done training in %.2fK" % (time.time() - start_train_time))
    min_dev_loss = np.min(dev_losses)
    min_dev_idx = np.argmin(dev_losses)
    print("Min val loss: %.3f at epoch %d" % (min_dev_loss, min_dev_idx))
    summary_writer.close()
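
The timeline branch in the last example uses the standard TF 1.x tracing hooks: run the step with FULL_TRACE options, collect RunMetadata, and convert the step stats into a Chrome trace that can be opened at chrome://tracing. A compact, self-contained version of that flow on a toy graph (the output path is illustrative):

import tensorflow as tf  # TF 1.x style API
from tensorflow.python.client import timeline

x = tf.random_normal([256, 256])
y = tf.matmul(x, x)

options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

with tf.Session() as sess:
    sess.run(y, options=options, run_metadata=run_metadata)

# Convert the collected step stats into a Chrome trace.
fetched_timeline = timeline.Timeline(run_metadata.step_stats)
chrome_trace = fetched_timeline.generate_chrome_trace_format()
with open("/tmp/timeline_step_0.json", "w") as f:
    f.write(chrome_trace)
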