Exemplo n.º 1
0
def compute_perplexity(model, sess, name):
    """Run the model over its entire eval iterator and report perplexity.

    ``model.eval(sess)`` is called repeatedly until TensorFlow raises
    ``OutOfRangeError`` (iterator exhausted). Per-batch losses are weighted
    by their batch size and divided by the total predicted-token count.

    Args:
      model: model for compute perplexity; its ``eval(sess)`` must return
        ``(loss, predict_count, batch_size)``.
      sess: tensorflow session to use.
      name: name of the batch, used only in the log line.

    Returns:
      The perplexity of the eval outputs, ``utils.safe_exp(loss / count)``.
      If no batch is produced the final division raises ZeroDivisionError.
    """
    loss_sum = 0
    token_sum = 0
    began = time.time()

    exhausted = False
    while not exhausted:
        try:
            batch_loss, batch_tokens, batch_len = model.eval(sess)
        except tf.errors.OutOfRangeError:
            exhausted = True
        else:
            # Loss is averaged per batch, so re-weight it by the batch size.
            loss_sum += batch_loss * batch_len
            token_sum += batch_tokens

    ppl = utils.safe_exp(loss_sum / token_sum)
    report = "  eval %s: perplexity %.2f" % (name, ppl)
    utils.print_time(report, began)
    return ppl
Exemplo n.º 2
0
def check_stats(stats, global_step, steps_per_stats, hparams, log_f):
    """Print averaged training statistics and report numerical overflow.

    Args:
      stats: accumulator dict read via the keys "step_time", "grad_norm",
        "loss", "predict_count", "total_count" and "learning_rate".
      global_step: current global step, used in the log lines.
      steps_per_stats: number of steps the accumulators cover.
      hparams: hyperparameters, forwarded to ``_get_best_results``.
      log_f: log file handle passed to ``utils.print_out``.

    Returns:
      True when the training perplexity is NaN, infinite, or above 1e20
      (an overflow message is logged in that case), otherwise False.
    """
    # Print statistics for the previous epoch.
    elapsed = stats["step_time"]
    avg_step_time = elapsed / steps_per_stats
    avg_grad_norm = stats["grad_norm"] / steps_per_stats
    train_ppl = utils.safe_exp(stats["loss"] / stats["predict_count"])
    # Words per second, expressed in thousands.
    speed = stats["total_count"] / (1000 * elapsed)
    fmt = ("  global step %d lr %g "
           "step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s")
    values = (global_step, stats["learning_rate"], avg_step_time, speed,
              train_ppl, avg_grad_norm, _get_best_results(hparams))
    utils.print_out(fmt % values, log_f)

    # Check for overflow
    overflowed = bool(math.isnan(train_ppl) or math.isinf(train_ppl)
                      or train_ppl > 1e20)
    if overflowed:
        utils.print_out("  step %d overflow, stop early" % global_step, log_f)
    return overflowed
Exemplo n.º 3
0
def run_eval(hparams, sess, eval_model, current_epoch, step_in_epoch):
    """Restore the latest checkpoint into the eval graph and compute perplexity.

    Args:
      hparams: hyperparameters; ``out_dir`` locates the checkpoint and
        ``batch_size`` weights each batch loss.
      sess: session bound to ``eval_model.graph``.
      eval_model: eval model bundle exposing ``graph``, ``saver``,
        ``iterator`` and ``model``.
      current_epoch: epoch index, used for logging and the table-init check.
      step_in_epoch: step inside the epoch, used likewise.

    Returns:
      ``(ppl, total_loss)``: eval perplexity and the batch-weighted loss sum.
    """
    with eval_model.graph.as_default():
        restorer = eval_model.saver
        checkpoint_path = tf.train.latest_checkpoint(hparams.out_dir)
        restorer.restore(sess, checkpoint_path)
        # Lookup tables are initialised only at this specific call point;
        # NOTE(review): presumably this is the first eval of a run -- confirm.
        is_first_eval = current_epoch == 0 and step_in_epoch == 100
        if is_first_eval:
            sess.run(tf.tables_initializer())

    sess.run(eval_model.iterator.initializer)
    loss_sum = 0.0
    predicted_sum = 0.0
    while True:
        try:
            (batch_loss, batch_predicted, _source, _target_input,
             _target_output) = eval_model.model.eval(sess)
        except tf.errors.OutOfRangeError:
            print("eval finished.")
            break
        # NOTE(review): weights by the configured batch size, which can
        # overweight a smaller final batch.
        loss_sum += batch_loss * hparams.batch_size
        predicted_sum += batch_predicted

    ppl = utils.safe_exp(loss_sum / predicted_sum)
    line = "epoch %d: step_in_epoch %d, eval  ppl %.2f" % (
        current_epoch, step_in_epoch, ppl)
    print(line)
    return ppl, loss_sum
Exemplo n.º 4
0
def train(hparams, scope=None, target_session=''):
    """Train the chatbot.

    Builds separate train/eval/infer models, which share parameters through
    the checkpoints written to ``hparams.out_dir``, and runs a first full
    evaluation. It then trains until ``hparams.num_train_steps`` global steps.
    Along the way it periodically logs statistics, saves checkpoints and runs
    sample decoding plus internal and external evaluation. After training it
    runs a final full evaluation of the last model and of each saved
    best-metric model.

    Args:
      hparams: hyperparameters. ``hparams.epoch_step`` is read to skip
        already-trained examples when resuming, and is updated in place as
        training progresses.
      scope: optional variable scope forwarded to the model builders.
      target_session: ``target`` argument passed to every ``tf.Session``.

    Returns:
      ``(dev_scores, test_scores, dev_ppl, test_ppl, global_step)`` from the
      final full evaluation.

    Raises:
      ValueError: if ``hparams.architecture`` is not "simple" or "hier".
    """
    # Initialize some local hyperparameters
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    # Internal evaluation happens every 10 statistics intervals.
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        # Default: one external evaluation per 5 internal evaluations.
        steps_per_external_eval = 5 * steps_per_eval

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator
        get_iterator = iterator_utils.get_iterator
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel

        # Bind the hierarchy-specific arguments (eou, dialogue_max_len) from
        # hparams now, so callers only supply the remaining arguments.
        def curry_get_infer_iterator(dataset, vocab_table, batch_size,
                                     src_reverse, eos, src_max_len):
            return end2end_iterator_utils.get_infer_iterator(
                dataset, vocab_table, batch_size, src_reverse, eos,
                src_max_len=src_max_len, eou=hparams.eou,
                dialogue_max_len=hparams.dialogue_max_len)
        get_infer_iterator = curry_get_infer_iterator

        def curry_get_iterator(src_dataset,
                               tgt_dataset,
                               vocab_table,
                               batch_size,
                               sos,
                               eos,
                               src_reverse,
                               random_seed,
                               num_buckets,
                               src_max_len=None,
                               tgt_max_len=None,
                               num_threads=4,
                               output_buffer_size=None,
                               skip_count=None):
            # num_buckets is forwarded as num_dialogue_buckets.
            return end2end_iterator_utils.get_iterator(
                src_dataset, tgt_dataset, vocab_table, batch_size, sos, eos,
                eou=hparams.eou, src_reverse=src_reverse,
                random_seed=random_seed, num_dialogue_buckets=num_buckets,
                src_max_len=src_max_len, tgt_max_len=tgt_max_len,
                num_threads=num_threads,
                output_buffer_size=output_buffer_size, skip_count=skip_count)

        get_iterator = curry_get_iterator
    else:
        raise ValueError("Unknown architecture", hparams.architecture)

    # Create three models which share parameters through the use of checkpoints
    train_model = create_train_model(model_creator, get_iterator, hparams, scope)
    eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
    infer_model = inference.create_infer_model(model_creator, get_infer_iterator, hparams, scope)
    # ToDo: adapt for architectures
    # Preload the data to use for sample decoding

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # Create the configurations for the sessions
    config_proto = utils.get_config_proto(log_device_placement=log_device_placement)
    # Create three sessions, one for each model
    train_sess = tf.Session(target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto, graph=infer_model.graph)

    # Load the train model from checkpoint or create a new one
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(train_model.model, model_dir,
                                                                            train_sess, name="train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)
    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    # Initialize the hyperparameters for the loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # epoch_step records where we were within an epoch. Used to skip trained on examples
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    # Initialize the training iterator
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    # Train until we reach num_steps.
    while global_step < num_train_steps:
        # Run a step
        start_step_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,  # The _ is the output of the update op
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # Decode and print a random sentence
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Perform external evaluation to save checkpoints if this is the best for some metric
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)
            # Reinitialize the iterator from the beginning
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_step_time)

        # Re-weight the per-batch loss by the batch size so the perplexity
        # below is averaged over predicted tokens.
        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # The model has screwed up
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            # Perform evaluation. Start by reassigning the last_eval_step variable to the current step
            last_eval_step = global_step
            # Print the progress and add summary
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run internal evaluation, and update the ppl variables. The data iterator is instantiated in the method.
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            # Run the external evaluation
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run external evaluation, updating metric scores in the meanwhile. The unneeded output is the global step.
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)

    # Done training. Save the model
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "chatbot.ckpt"),
        global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out("# Best %s, step %d "
                        "step-time %.2f wps %.2fK, %s, %s" %
                        (metric, best_global_step, avg_step_time, speed,
                         result_summary, time.ctime()), log_f)

    summary_writer.close()
    # Close the log file so buffered log output is flushed and the handle
    # is released.
    log_f.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
Exemplo n.º 5
0
def train(hparams, scope=None, target_session=""):
    """Train a translation model, optionally following a curriculum.

    Builds separate train/eval/infer models, runs a first full evaluation,
    then trains until ``hparams.num_train_steps`` global steps. It
    periodically logs statistics, saves checkpoints and runs sample decoding
    plus internal and external evaluation. It ends with a final full
    evaluation of the last model and of each saved best-metric model.

    ``hparams.curriculum`` selects how training batches are drawn:
      * 'none': a single training iterator.
      * 'predictive_gain': an ``Exp3S`` instance draws the curriculum bucket
        for each step (``draw_task``). It is then updated with the drop in
        loss on the same source batch re-evaluated after the update.
      * 'look_back_and_forward': train on the current curriculum bucket with
        probability 0.8, otherwise on a uniformly random bucket. Advance to
        the next bucket once the step loss falls below
        ``curriculum_progress_loss`` times the midpoint of the bucket's
        (presumably source-length) range. After the last bucket, buckets are
        sampled uniformly.

    Args:
      hparams: hyperparameters. ``hparams.epoch_step`` is read to skip
        already-trained examples when resuming, and is updated in place.
      scope: optional variable scope forwarded to the model builders.
      target_session: ``target`` argument passed to every ``tf.Session``.

    Returns:
      ``(dev_scores, test_scores, dev_ppl, test_ppl, global_step)`` from the
      final full evaluation.

    Raises:
      ValueError: for an unknown attention architecture or curriculum.
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    # Reject unsupported curricula up front. Otherwise `lesson` would never
    # be assigned and the first training step would die with a NameError.
    if hparams.curriculum not in ("none", "predictive_gain",
                                  "look_back_and_forward"):
        raise ValueError("Unknown curriculum: %s" % hparams.curriculum)

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="w")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)

    if hparams.curriculum == 'none':
        train_sess.run(
            train_model.iterator.initializer,
            feed_dict={train_model.skip_count_placeholder: skip_count})
    else:
        if hparams.curriculum == 'predictive_gain':
            exp3s = Exp3S(hparams.num_curriculum_buckets, 0.001, 0, 0.05)
        elif hparams.curriculum == 'look_back_and_forward':
            # Index of the bucket currently being learned; it stops at
            # num_curriculum_buckets once every bucket has been passed.
            curriculum_point = 0

        # One sub-iterator per curriculum bucket. Each step feeds one of
        # `iterator_handles` into `handle` to pick the bucket.
        handle = train_model.iterator.handle
        for i in range(hparams.num_curriculum_buckets):
            train_sess.run(
                train_model.iterator.initializer[i].initializer,
                feed_dict={train_model.skip_count_placeholder: skip_count})

        iterator_handles = [
            train_sess.run(
                train_model.iterator.initializer[i].string_handle(),
                feed_dict={train_model.skip_count_placeholder: skip_count})
            for i in range(hparams.num_curriculum_buckets)
        ]

    utils.print_out("Starting training")

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            if hparams.curriculum != 'none':
                if hparams.curriculum == 'predictive_gain':
                    lesson = exp3s.draw_task()
                elif hparams.curriculum == 'look_back_and_forward':
                    if curriculum_point >= hparams.num_curriculum_buckets:
                        # Curriculum completed: sample buckets uniformly.
                        lesson = np.random.randint(
                            low=0, high=hparams.num_curriculum_buckets)
                    else:
                        # Mostly the current bucket, sometimes a random one.
                        lesson = curriculum_point if np.random.random_sample(
                        ) < 0.8 else np.random.randint(
                            low=0, high=hparams.num_curriculum_buckets)

                step_result = loaded_train_model.train(
                    hparams,
                    train_sess,
                    handle=handle,
                    iterator_handle=iterator_handles[lesson],
                    use_fed_source_placeholder=loaded_train_model.
                    use_fed_source,
                    fed_source_placeholder=loaded_train_model.fed_source)

                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size, source) = step_result

                if hparams.curriculum == 'predictive_gain':
                    # Re-evaluate the loss on the same source batch after the
                    # update. Fetch the tensor itself (not a one-element list)
                    # so `new_loss`, and hence the gain `v`, is a scalar.
                    new_loss = train_sess.run(
                        loaded_train_model.train_loss,
                        feed_dict={
                            handle: iterator_handles[lesson],
                            loaded_train_model.use_fed_source: True,
                            loaded_train_model.fed_source: source
                        })

                    curriculum_point_a = lesson * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (
                        lesson + 1) * (hparams.src_max_len //
                                       hparams.num_curriculum_buckets) + 1

                    v = step_loss - new_loss
                    exp3s.update_w(
                        v,
                        float(curriculum_point_a + curriculum_point_b) / 2.0)
                elif hparams.curriculum == 'look_back_and_forward':
                    utils.print_out("step loss: %s, lesson: %s" %
                                    (step_loss, lesson))
                    curriculum_point_a = curriculum_point * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (curriculum_point + 1) * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1

                    # Never advance past the last bucket: a larger value
                    # would later be used as an out-of-range index into
                    # `iterator_handles`.
                    if (curriculum_point < hparams.num_curriculum_buckets and
                            step_loss < (hparams.curriculum_progress_loss *
                                         (float(curriculum_point_a +
                                                curriculum_point_b) / 2.0))):
                        curriculum_point += 1
            else:
                step_result = loaded_train_model.train(hparams, train_sess)
                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            # External evaluation at epoch end is currently disabled:
            # utils.print_out(
            #     "# Finished an epoch, step %d. Perform external evaluation" %
            #     global_step)
            # run_sample_decode(infer_model, infer_sess,
            #                   model_dir, hparams, summary_writer, sample_src_data,
            #                   sample_tgt_data)
            # dev_scores, test_scores, _ = run_external_eval(
            #     infer_model, infer_sess, model_dir,
            #     hparams, summary_writer)
            if hparams.curriculum == 'none':
                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            else:
                # Only the exhausted bucket's iterator is restarted.
                train_sess.run(
                    train_model.iterator.initializer[lesson].initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            if hparams.curriculum == 'predictive_gain':
                utils.print_out("lesson: %s, step loss: %s, new_loss: %s" %
                                (lesson, step_loss, new_loss))
                utils.print_out("exp3s dist: %s" % (exp3s.pi, ))

            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss /
                                       checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)

            if math.isnan(train_ppl):
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)

            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

        # Separate external-eval schedule, currently disabled (external eval
        # runs together with internal eval above):
        # if global_step - last_external_eval_step >= steps_per_external_eval:
        #   last_external_eval_step = global_step

        #   # Save checkpoint
        #   loaded_train_model.saver.save(
        #       train_sess,
        #       os.path.join(out_dir, "translate.ckpt"),
        #       global_step=global_step)
        #   run_sample_decode(infer_model, infer_sess,
        #                     model_dir, hparams, summary_writer, sample_src_data,
        #                     sample_tgt_data)
        #   dev_scores, test_scores, _ = run_external_eval(
        #       infer_model, infer_sess, model_dir,
        #       hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)

    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)

    summary_writer.close()
    # Close the log file so buffered log output is flushed and the handle
    # is released.
    log_f.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)