Esempio n. 1
0
def process_stats(stats, info, global_step, steps_per_stats, log_f):
    """Fill ``info`` with averaged training statistics and detect overflow.

    Args:
        stats: accumulated counters for the last stats window. Reads the keys
            "step_time", "grad_norm", "sequence_count", "word_count",
            "train_loss" and "predict_count".
        info: dict updated in place with "avg_step_time", "avg_grad_norm",
            "avg_sequence_count", "speed" and "train_ppl".
        global_step: current global step; only used in the overflow message.
        steps_per_stats: number of steps the counters were accumulated over.
        log_f: log file handle forwarded to ``utils.print_out``.

    Returns:
        True if the training perplexity is NaN, infinite or above 1e20
        (after logging an early-stop message), otherwise False.
    """
    window_time = stats["step_time"]

    # Per-step averages over the stats window.
    for target_key, source_key in (("avg_step_time", "step_time"),
                                   ("avg_grad_norm", "grad_norm"),
                                   ("avg_sequence_count", "sequence_count")):
        info[target_key] = stats[source_key] / steps_per_stats
    # Throughput in thousands of words per second of step time.
    info["speed"] = stats["word_count"] / (1000 * window_time)

    # Per-prediction perplexity.
    ppl = utils.safe_exp(stats["train_loss"] / stats["predict_count"])
    info["train_ppl"] = ppl

    overflowed = bool(not math.isfinite(ppl) or ppl > 1e20)
    if overflowed:
        utils.print_out("  step %d overflow, stop early" % global_step, log_f)
    return overflowed
Esempio n. 2
0
    def compute_perplexity(self, sess, name):
        """
        Compute perplexity of the output of the model.

        Repeatedly calls ``self.eval(sess)`` until it raises
        ``tf.errors.OutOfRangeError``. Each call yields
        ``(loss, predict_count, batch_size)``. The loss is weighted by
        ``batch_size`` (presumably it is a per-sequence mean; confirm against
        ``eval``) and the total is divided by the number of predictions.
        The result is logged under ``name`` and returned.
        """
        loss_sum = 0
        prediction_total = 0

        while True:
            try:
                batch_loss, batch_predictions, batch_size = self.eval(sess)
            except tf.errors.OutOfRangeError:
                # Evaluation iterator exhausted.
                break
            loss_sum += batch_loss * batch_size
            prediction_total += batch_predictions

        perplexity = utils.safe_exp(loss_sum / prediction_total)
        utils.log("%s perplexity: %.2f" % (name, perplexity))
        return perplexity
Esempio n. 3
0
def check_stats(stats, global_step, steps_per_stats, hparams, log_f):
    """Log averaged training statistics and report numerical overflow.

    Args:
        stats: accumulated counters for the last stats window. Reads the keys
            "step_time", "grad_norm", "loss", "predict_count", "total_count"
            and "learning_rate".
        global_step: current global step, used in the log messages.
        steps_per_stats: number of steps the counters cover.
        hparams: passed to ``_get_best_results`` to build the log line suffix.
        log_f: log file handle forwarded to ``utils.print_out``.

    Returns:
        True if the training perplexity is NaN, infinite or above 1e20
        (after logging an early-stop message), otherwise False.
    """
    window_time = stats["step_time"]
    avg_step_time = window_time / steps_per_stats
    avg_grad_norm = stats["grad_norm"] / steps_per_stats
    train_ppl = utils.safe_exp(stats["loss"] / stats["predict_count"])
    # Thousands of words processed per second of step time.
    speed = stats["total_count"] / (1000 * window_time)

    fields = (global_step, stats["learning_rate"], avg_step_time, speed,
              train_ppl, avg_grad_norm, _get_best_results(hparams))
    utils.print_out(
        "  global step %d lr %g "
        "step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s" % fields, log_f)

    # Healthy perplexity: finite and not absurdly large.
    if math.isfinite(train_ppl) and train_ppl <= 1e20:
        return False
    utils.print_out("  step %d overflow, stop early" % global_step, log_f)
    return True
Esempio n. 4
0
def compute_perplexity(model, sess, name, eval_handle):
    """Compute perplexity of the output of the model based on loss function.

    Calls ``model.eval(sess, eval_handle)`` until it raises
    ``tf.errors.OutOfRangeError``. Each call yields
    ``(loss, all_summaries, predict_count, batch_size)``.

    Returns:
        A tuple ``(perplexity, aggregated_summaries)``:
        - ``perplexity`` is ``exp(sum(loss * batch_size) / sum(predict_count))``
          computed via ``utils.safe_exp``.
        - ``aggregated_summaries`` maps each summary key to its sum over all
          batches. Keys outside ``summed_only_keys`` are then divided by the
          number of batches, i.e. averaged.
    """
    # Summaries reported as totals over all batches instead of per-batch
    # averages. Built once rather than per key.
    summed_only_keys = frozenset(
        ["eval_dialogue_loss1", "eval_dialogue_loss2", "eval_action_loss3"])

    total_loss = 0
    total_predict_count = 0
    start_time = time.time()
    aggregated_summaries = {}
    batch_processed = 0
    while True:
        try:
            loss, all_summaries, predict_count, batch_size = model.eval(
                sess, eval_handle)
        except tf.errors.OutOfRangeError:
            # Evaluation iterator exhausted.
            break
        total_loss += loss * batch_size
        batch_processed += 1
        total_predict_count += predict_count
        for key, value in all_summaries.items():
            aggregated_summaries[key] = aggregated_summaries.get(key, 0.0) + value

    perplexity = utils.safe_exp(total_loss / total_predict_count)
    for key in aggregated_summaries:
        if key not in summed_only_keys:
            aggregated_summaries[key] /= batch_processed
    utils.print_time("  eval %s: perplexity %.2f" % (name, perplexity),
                     start_time)
    return perplexity, aggregated_summaries
Esempio n. 5
0
def train(hparams, identity, scope=None, target_session=""):
  """Main loop to train the dialogue model.

  Creates the train model and its session, then creates or restores model
  parameters from ``hparams.out_dir``. It runs training steps until
  ``hparams.num_train_steps`` is reached, or stops early if the training
  perplexity becomes NaN. A ``dialogue.ckpt`` checkpoint is saved once per
  stats window and once more when training ends.

  Args:
    hparams: hyperparameters. ``hparams.epoch_step`` is updated in place. It
      counts the steps taken in the current epoch and is reset to 0 at each
      epoch boundary. At start-up it determines how many examples to skip.
    identity: prefix prepended to the log file name (``<identity>log_<ts>``).
    scope: optional variable scope passed to ``create_train_model``.
    target_session: TensorFlow session target.
  """
  out_dir = hparams.out_dir
  steps_per_stats = hparams.steps_per_stats
  # The train_ppl summary is written once every three stats windows.
  steps_per_internal_eval = 3 * steps_per_stats

  model_creator = diag_model.Model

  train_model = model_helper.create_train_model(model_creator, hparams, scope)

  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, identity+"log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  avg_step_time = 0.0

  # load TensorFlow session and model
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)

  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)

  train_handle = train_sess.run(train_model.train_iterator.string_handle())

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # initialize summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)

  last_stats_step = global_step
  last_eval_step = global_step

  # initialize training stats.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()),
      log_f)

  # initialize iterators, skipping the examples already consumed in the
  # current epoch (epoch_step batches of batch_size examples).
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.train_iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  # main training loop
  while global_step < hparams.num_train_steps:
    start_time = time.time()
    try:  #  run a step
      step_result = loaded_train_model.train(train_sess, train_handle)
      # words1/words2/mask1/mask2 are only used for debug printing below.
      (_, step_loss, all_summaries, step_predict_count, step_summary,
       global_step, step_word_count, batch_size, _, _, words1, words2, mask1,
       mask2) = step_result
      hparams.epoch_step += 1

    except tf.errors.OutOfRangeError:  # finished an epoch
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.train_iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    for key in all_summaries:
      utils.add_summary(summary_writer, global_step, key, all_summaries[key])

    # update statistics
    step_time += (time.time() - start_time)

    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    if global_step - last_stats_step >= steps_per_stats:
      # print statistics for the previous epoch and save the model.
      last_stats_step = global_step

      avg_step_time = step_time / steps_per_stats
      utils.add_summary(summary_writer, global_step, "step_time", avg_step_time)
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      if math.isnan(train_ppl):
        # Training diverged; stop and fall through to the final save.
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

      # save the model
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "dialogue.ckpt"),
          global_step=global_step)

      # print the dialogue if in debug mode
      if hparams.debug:
        utils.print_current_dialogue(words1, words2, mask1, mask2)

    #  write out internal evaluation
    if global_step - last_eval_step >= steps_per_internal_eval:
      last_eval_step = global_step

      utils.print_out("# Internal Evaluation. global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

  # finished training
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "dialogue.ckpt"),
      global_step=global_step)
  result_summary = ""
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()),
      log_f)
  utils.print_time("# Done training!", start_train_time)
  # NOTE(review): no best-model evaluation follows this message here.
  utils.print_out("# Start evaluating saved best models.")
  summary_writer.close()
  # Close the log file so any buffered output is written out, and release
  # the session; neither is used after this point.
  log_f.close()
  train_sess.close()
Esempio n. 6
0
def train(hparams, scope=None, target_session=''):
    """Train the chatbot.

    Builds train/eval/infer models that share parameters through checkpoints
    in ``hparams.out_dir``. It runs a full evaluation before training, then
    trains until ``hparams.num_train_steps``. Every ``10 * steps_per_stats``
    steps it saves ``chatbot.ckpt``, runs sample decoding and runs internal
    evaluation. Every ``steps_per_external_eval`` steps (default
    ``5 * steps_per_eval``) and at the end of each epoch it runs external
    evaluation. After training it evaluates the final model and the best
    model saved for each metric in ``hparams.metrics``.

    Args:
        hparams: hyperparameters. ``architecture`` must be "simple" or
            "hier". ``epoch_step`` is updated in place.
        scope: optional variable scope passed to the model creators.
        target_session: TensorFlow session target.

    Returns:
        ``(dev_scores, test_scores, dev_ppl, test_ppl, global_step)`` from the
        final full evaluation of the last model.

    Raises:
        ValueError: if ``hparams.architecture`` is not recognised.
    """
    # Initialize some local hyperparameters
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator
        get_iterator = iterator_utils.get_iterator
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel
        # Bind the hierarchical-only arguments (eou, dialogue_max_len) so the
        # iterator builders match the signatures of the "simple" ones.
        def curry_get_infer_iterator(dataset, vocab_table, batch_size, src_reverse,
                       eos, src_max_len):
            return end2end_iterator_utils.get_infer_iterator(dataset, vocab_table, batch_size, src_reverse, eos,
                                                      src_max_len=src_max_len, eou=hparams.eou,
                                                      dialogue_max_len=hparams.dialogue_max_len)
        get_infer_iterator = curry_get_infer_iterator

        def curry_get_iterator(src_dataset,
                 tgt_dataset,
                 vocab_table,
                 batch_size,
                 sos,
                 eos,
                 src_reverse,
                 random_seed,
                 num_buckets,
                 src_max_len=None,
                 tgt_max_len=None,
                 num_threads=4,
                 output_buffer_size=None,
                 skip_count=None):
            return end2end_iterator_utils.get_iterator(src_dataset, tgt_dataset, vocab_table, batch_size, sos, eos,
                                                eou=hparams.eou, src_reverse=src_reverse, random_seed=random_seed,
                                                num_dialogue_buckets=num_buckets, src_max_len=src_max_len,
                                                tgt_max_len=tgt_max_len, num_threads=num_threads,
                                                output_buffer_size=output_buffer_size, skip_count=skip_count)

        get_iterator = curry_get_iterator
    else:
        raise ValueError("Unknown architecture", hparams.architecture)

    # Create three models which share parameters through the use of checkpoints
    train_model = create_train_model(model_creator, get_iterator, hparams, scope)
    eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
    infer_model = inference.create_infer_model(model_creator, get_infer_iterator, hparams, scope)
    # ToDo: adapt for architectures
    # Preload the data to use for sample decoding

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # Create the configurations for the sessions
    config_proto = utils.get_config_proto(log_device_placement=log_device_placement)
    # Create three sessions, one for each model
    train_sess = tf.Session(target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto, graph=infer_model.graph)

    # Load the train model from checkpoint or create a new one
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(train_model.model, model_dir,
                                                                            train_sess, name="train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)
    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    # Initialize the hyperparameters for the loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # epoch_step records where we were within an epoch. Used to skip trained on examples
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    # Initialize the training iterator
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    # Train until we reach num_steps.
    while global_step < num_train_steps:
        # Run a step
        start_step_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,  # The _ is the output of the update op
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # Decode and print a random sentence
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Perform external evaluation to save checkpoints if this is the best for some metric
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)
            # Reinitialize the iterator from the beginning
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_step_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # The model has diverged; stop and go to the final evaluation.
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            # Perform evaluation. Start by reassigning the last_eval_step variable to the current step
            last_eval_step = global_step
            # Print the progress and add summary
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run internal evaluation, and update the ppl variables. The data iterator is instantiated in the method.
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            # Run the external evaluation
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run external evaluation, updating metric scores in the meanwhile. The unneeded output is the global step.
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)

    # Done training. Save the model
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "chatbot.ckpt"),
        global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out("# Best %s, step %d "
                        "step-time %.2f wps %.2fK, %s, %s" %
                        (metric, best_global_step, avg_step_time, speed,
                         result_summary, time.ctime()), log_f)

    summary_writer.close()
    # Close the log file so any buffered output is written out, and release
    # the three sessions; none of them is used after this point.
    log_f.close()
    for sess in (train_sess, eval_sess, infer_sess):
        sess.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
Esempio n. 7
0
def train(hparams, scope=None, target_session="", single_cell_fn=None):
    """Train a translation model.

    The model class is chosen from ``hparams.attention`` and
    ``hparams.attention_architecture``. Training runs until
    ``hparams.num_train_steps``.

    With ``hparams.eval_on_fly`` set, eval and infer models and sessions are
    also created. The model is evaluated before training, every
    ``10 * steps_per_stats`` steps (internal eval) and every
    ``steps_per_external_eval`` steps (default ``2 * steps_per_eval``).
    After training, the final model and the best model per metric are
    evaluated.

    Without it, a ``translate.ckpt`` snapshot is simply saved every
    ``hparams.snapshot_interval`` steps.

    Args:
        hparams: hyperparameters. ``epoch_step`` is updated in place.
        scope: optional variable scope passed to the model creators.
        target_session: TensorFlow session target.
        single_cell_fn: optional cell factory passed to the model creators.

    Returns:
        ``(dev_scores, test_scores, dev_ppl, test_ppl, global_step)`` when
        ``hparams.eval_on_fly`` is set, otherwise just ``global_step``.

    Raises:
        ValueError: if the attention architecture is not recognised.
    """

    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats

    if hparams.eval_on_fly:
        steps_per_external_eval = hparams.steps_per_external_eval
        steps_per_eval = 10 * steps_per_stats

        if not steps_per_external_eval:
            steps_per_external_eval = 2 * steps_per_eval
    else:
        steps_per_snapshot = hparams.snapshot_interval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = create_train_model(model_creator, hparams, scope,
                                     single_cell_fn)

    if hparams.eval_on_fly:
        eval_model = create_eval_model(model_creator, hparams, scope,
                                       single_cell_fn)
        infer_model = inference.create_infer_model(model_creator, hparams,
                                                   scope, single_cell_fn)

        # Preload data for sample decoding.
        dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
        dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
        sample_src_data = inference.load_data(dev_src_file)
        sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)

    if hparams.eval_on_fly:
        eval_sess = tf.Session(target=target_session,
                               config=config_proto,
                               graph=eval_model.graph)
        infer_sess = tf.Session(target=target_session,
                                config=config_proto,
                                graph=infer_model.graph)

    with train_model.graph.as_default():
        model_helper.initialize_cnn(train_model.model, train_sess)
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    if hparams.eval_on_fly:
        run_full_eval(model_dir, infer_model, infer_sess, eval_model,
                      eval_sess, hparams, summary_writer, sample_src_data,
                      sample_tgt_data)

    last_stats_step = global_step

    if hparams.eval_on_fly:
        last_eval_step = global_step
        last_external_eval_step = global_step
    else:
        last_snapshot_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators, skipping the examples already consumed
    # in the current epoch (epoch_step batches of batch_size examples).
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            if hparams.eval_on_fly:
                run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer, sample_src_data,
                                  sample_tgt_data)
                dev_scores, test_scores, _ = run_external_eval(
                    infer_model, infer_sess, model_dir, hparams,
                    summary_writer)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss /
                                       checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g step-time %.2fs wps %.2fK ppl %.2f %s"
                %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # Training diverged; stop and fall through to the final save.
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        # Snapshot-only mode: periodic checkpoint without evaluation.
        if (not hparams.eval_on_fly) and (global_step - last_snapshot_step >=
                                          steps_per_snapshot):
            last_snapshot_step = global_step
            utils.print_out("# Cihan: Saving Snapshot, global step %d" %
                            global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

        if hparams.eval_on_fly and (global_step - last_eval_step >=
                                    steps_per_eval):
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)

        if hparams.eval_on_fly and (global_step - last_external_eval_step >=
                                    steps_per_external_eval):
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    if hparams.eval_on_fly:
        result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
            model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)

        utils.print_out(
            "# Final, step %d lr %g step-time %.2f wps %.2fK ppl %.2f, %s, %s"
            % (global_step,
               loaded_train_model.learning_rate.eval(session=train_sess),
               avg_step_time, speed, train_ppl, result_summary, time.ctime()),
            log_f)
        utils.print_time("# Done training!", start_train_time)

        utils.print_out("# Start evaluating saved best models.")
        for metric in hparams.metrics:
            best_model_dir = getattr(hparams, "best_" + metric + "_dir")
            result_summary, best_global_step, _, _, _, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            utils.print_out(
                "# Best %s, step %d step-time %.2f wps %.2fK, %s, %s" %
                (metric, best_global_step, avg_step_time, speed,
                 result_summary, time.ctime()), log_f)
    else:
        # No final evaluation in snapshot mode, but still report the
        # total training time.
        utils.print_time("# Done training!", start_train_time)

    summary_writer.close()
    # Close the log file so any buffered output is written out, and release
    # the sessions created above; none of them is used after this point.
    log_f.close()
    train_sess.close()
    if hparams.eval_on_fly:
        eval_sess.close()
        infer_sess.close()

    if hparams.eval_on_fly:
        return dev_scores, test_scores, dev_ppl, test_ppl, global_step
    else:
        return global_step
Esempio n. 8
0
def train(hparams):
    train_model = mc.create_train_model(hparams)
    eval_model = mc.create_eval_model(hparams)
    infer_model = mc.create_infer_model(hparams)
    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=hparams.log_device_placement, allow_soft_placement=True)
    train_sess = tf.Session(graph=train_model.graph, config=config_proto)
    eval_sess = tf.Session(graph=eval_model.graph, config=config_proto)
    infer_sess = tf.Session(graph=infer_model.graph, config=config_proto)

    with train_model.graph.as_default():
        train_model, global_step = model_helper.create_or_load_model(train_model, hparams.out_dir, train_sess, "train")
    train_sess.run(train_model.iterator.initializer)

    step_in_epoch = 0
    current_epoch = 0
    chkpt_predicted_count, chkpt_train_loss = 0.0, 0.0

    start_time = time.time()
    while current_epoch < hparams.epoch:
        try:
            _, \
            train_loss, \
            step_predicted_count, \
            source, \
            target_input, \
            target_output, \
            logits, \
            final_context_state, \
            sample_id, \
            global_step = train_model.model.train(train_sess)
            step_in_epoch += 1
        except tf.errors.OutOfRangeError:
            step_in_epoch = 0
            print ("epoch %s finished" % str(current_epoch))

            infer_util.run_infer(hparams, infer_sess, infer_model, current_epoch, ["bleu", "rouge", "accuracy"])

            current_epoch = current_epoch + 1
            train_sess.run(train_model.iterator.initializer)
            continue

        # train_loss = step_results[0]
        # step_predicted_count = step_results[1]

        chkpt_predicted_count += step_predicted_count
        chkpt_train_loss += (train_loss * hparams.batch_size)

        if step_in_epoch % hparams.steps_per_eval == 0 and step_in_epoch > 0:
            # if hparams.time_major:
            #     source = np.transpose(source)
            #     target_input = np.transpose(target_input)
            #     target_output = np.transpose(target_output)
            print ("global step: ", str(global_step))
            # print (source[0])
            # print (target_input[0])
            # print (target_output[0])
            # print (np.shape(logits))
            # print (final_context_state)
            # print (sample_id)

            # user_input = input("please input")
            # print (user_input)
            train_ppl = utils.safe_exp(chkpt_train_loss / chkpt_predicted_count)
            print ("epoch %d: step_in_epoch %d, train ppl %.2f, avg loss %.2f, lr %.7f, time %.2fs"
                   % (current_epoch,
                      step_in_epoch,
                      train_ppl,
                      chkpt_train_loss/(hparams.steps_per_eval * hparams.batch_size),
                      train_model.model.learning_rate.eval(session=train_sess),
                      time.time() - start_time))
            start_time = time.time()
            chkpt_predicted_count, chkpt_train_loss = 0.0, 0.0

            train_model.saver.save(train_sess, os.path.join(hparams.out_dir, "translate.ckpt"), global_step=current_epoch)

            eval_ppl, eval_loss = eval_util.run_eval(hparams, eval_sess, eval_model, current_epoch, step_in_epoch, )