Example #1
def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix, inference_indices,
                              tgt_eos, subword_option):
    """Decoding only a specific set of sentences."""
    utils.print_out("  decoding to output %s , num sents %d." %
                    (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(tf.gfile.GFile(output_infer,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            nmt_outputs, infer_summary = model.decode(sess)

            # get text translation
            assert nmt_outputs.shape[0] == 1
            translation = nmt_utils.get_translation(
                nmt_outputs,
                sent_id=0,
                tgt_eos=tgt_eos,
                subword_option=subword_option)

            if infer_summary is not None:  # Attention models
                image_file = output_infer_summary_prefix + str(
                    decode_id) + ".png"
                utils.print_out("  save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                with tf.gfile.GFile(image_file, mode="w") as img_f:
                    img_f.write(image_summ.value[0].image.encoded_image_string)
            trans_f.write("%s\n" % translation)
            utils.print_out(translation + b"\n")
    utils.print_time("  done", start_time)
Example #2
def decode_and_evaluate(name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        subword_option,
                        beam_width,
                        tgt_eos,
                        num_translations_per_input=1,
                        decode=True):
    """Decode a test set and compute a score according to the evaluation task."""
    # Decode
    if decode:
        utils.print_out("  decoding to output %s." % trans_file)

        start_time = time.time()
        num_sentences = 0
        with codecs.getwriter("utf-8")(
                tf.gfile.GFile(trans_file, mode="wb")) as trans_f:
            trans_f.write("")  # Write empty string to ensure file is created.

            num_translations_per_input = max(
                min(num_translations_per_input, beam_width), 1)
            while True:
                try:
                    nmt_outputs, _ = model.decode(sess)
                    if beam_width == 0:
                        nmt_outputs = np.expand_dims(nmt_outputs, 0)

                    batch_size = nmt_outputs.shape[1]
                    num_sentences += batch_size

                    for sent_id in range(batch_size):
                        for beam_id in range(num_translations_per_input):
                            translation = get_translation(
                                nmt_outputs[beam_id],
                                sent_id,
                                tgt_eos=tgt_eos,
                                subword_option=subword_option)
                            trans_f.write((translation + b"\n").decode("utf-8"))
                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  done, num sentences %d, num translations per input %d" %
                        (num_sentences, num_translations_per_input), start_time)
                    break

    # Evaluation
    evaluation_scores = {}
    if ref_file and tf.gfile.Exists(trans_file):
        for metric in metrics:
            score = evaluation_utils.evaluate(
                ref_file,
                trans_file,
                metric,
                subword_option=subword_option)
            evaluation_scores[metric] = score
            utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores
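The examples above rely on a get_translation helper that is not shown on this page. As a rough sketch (names and details here are assumptions, not the actual nmt_utils implementation), it takes one row of decoded tokens, truncates at the first target EOS token, and optionally merges BPE subword pieces back into words; the real helper works on bytes, which is why the callers append b"\n" and decode to UTF-8 before writing.

import numpy as np

def get_translation_sketch(nmt_outputs, sent_id, tgt_eos, subword_option=None):
    """Hypothetical stand-in for get_translation: join one decoded sentence."""
    output = list(nmt_outputs[sent_id, :])
    if tgt_eos and tgt_eos in output:
        output = output[:output.index(tgt_eos)]       # cut at end-of-sentence
    translation = " ".join(output)
    if subword_option == "bpe":
        translation = translation.replace("@@ ", "")  # merge "new@@ er" -> "newer"
    return translation

# get_translation_sketch(np.array([["he", "is", "here", "</s>"]]), 0, "</s>")
# -> "he is here"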
Example #3
def compute_perplexity(model, sess, name):
    """Compute perplexity of the output of the model.

  Args:
    model: model for compute perplexity.
    sess: tensorflow session to use.
    name: name of the batch.

  Returns:
    The perplexity of the eval outputs.
  """
    total_loss = 0
    total_predict_count = 0
    start_time = time.time()

    while True:
        try:
            loss, predict_count, batch_size = model.eval(sess)
            total_loss += loss * batch_size
            total_predict_count += predict_count
        except tf.errors.OutOfRangeError:
            break

    perplexity = utils.safe_exp(total_loss / total_predict_count)
    utils.print_time("  eval %s: perplexity %.2f" % (name, perplexity),
                     start_time)
    return perplexity
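The perplexity here is just the exponentiated average per-token loss: loss is accumulated as loss * batch_size, then divided by the total number of predicted target tokens before exponentiating. utils.safe_exp presumably guards the exponential against overflow; a minimal sketch under that assumption:

import math

def safe_exp_sketch(value):
    # Assumed behaviour of utils.safe_exp: return inf instead of raising on overflow.
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")

total_loss = 5235.0          # sum over batches of loss * batch_size
total_predict_count = 1800   # total number of predicted target tokens
print(safe_exp_sketch(total_loss / total_predict_count))   # exp(2.91) ~ 18.3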
Example #4
def decode_and_evaluate(name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        bpe_delimiter,
                        beam_width,
                        tgt_eos,
                        decode=True):
  """Decode a test set and compute a score according to the evaluation task."""
  # Decode
  if decode:
    utils.print_out("  decoding to output %s." % trans_file)

    start_time = time.time()
    num_sentences = 0
    with codecs.getwriter("utf-8")(
        tf.gfile.GFile(trans_file, mode="wb")) as trans_f:
      trans_f.write("")  # Write empty string to ensure file is created.

      while True:
        try:
          nmt_outputs, _ = model.decode(sess)

          if beam_width > 0:
            # get the top translation.
            nmt_outputs = nmt_outputs[0]

          num_sentences += len(nmt_outputs)
          for sent_id in range(len(nmt_outputs)):
            translation = get_translation(
                nmt_outputs,
                sent_id,
                tgt_eos=tgt_eos,
                bpe_delimiter=bpe_delimiter)
            trans_f.write((translation + b"\n").decode("utf-8"))
        except tf.errors.OutOfRangeError:
          utils.print_time("  done, num sentences %d" % num_sentences,
                           start_time)
          break

  # Evaluation
  evaluation_scores = {}
  if ref_file and tf.gfile.Exists(trans_file):
    for metric in metrics:
      score = evaluation_utils.evaluate(
          ref_file,
          trans_file,
          metric,
          bpe_delimiter=bpe_delimiter)
      evaluation_scores[metric] = score
      utils.print_out("  %s %s: %.1f" % (metric, name, score))

  return evaluation_scores
Example #5
def decode_and_evaluate(name,
                        model,
                        sess,
                        output_file,
                        reference_file,
                        metrics,
                        bpe_delimiter,
                        beam_width,
                        eos,
                        number_token=None,
                        name_token=None,
                        decode=True):
    """Decode a test set and compute a score according to the evaluation task."""
    # Decode
    if decode:
        utils.print_out("  decoding to output %s." % output_file)
        start_time = time.time()
        num_sentences = 0
        with tf.gfile.GFile(output_file, mode="w+") as out_f:
            out_f.write("")  # Write empty string to ensure file is created.

            while True:
                try:
                    # Get the response(s) for each input in the batch (whole file in this case)
                    # ToDo: adapt for architectures
                    outputs, infer_summary = model.decode(sess)

                    if beam_width > 0:
                        # Get the top response if we used beam_search
                        outputs = outputs[0]

                    num_sentences += len(outputs)
                    # Iterate over the outputs and write them to file
                    for sent_id in range(len(outputs)):
                        response = postprocess_output(outputs, sent_id, eos,
                                                      bpe_delimiter,
                                                      number_token, name_token)
                        out_f.write("%s\n" % response)
                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  done, num sentences %d" % num_sentences, start_time)
                    break

    # Evaluation
    evaluation_scores = {}
    if reference_file and tf.gfile.Exists(output_file):
        for metric in metrics:
            score = evaluation_utils.evaluate(ref_file=reference_file,
                                              trans_file=output_file,
                                              metric=metric,
                                              bpe_delimiter=bpe_delimiter)
            evaluation_scores[metric] = score
            utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores
Example #6
def compute_perplexity(model, sess, name):
    total_loss = 0
    total_predict_count = 0
    start_time = time.time()

    while True:
        try:
            loss, predict_count, batch_size = model.eval(sess)
            total_loss += loss * batch_size
            total_predict_count += predict_count
        except tf.errors.OutOfRangeError:
            break

    perplexity = utils.safe_exp(total_loss / total_predict_count)
    utils.print_time("  eval %s: perplexity %.2f" % (name, perplexity),
                     start_time)
    return perplexity
Example #7
def _decode_inference_indices(model, sess,
                              output_infer_file,
                              output_infer_summary_prefix,
                              inference_indices,
                              eos,
                              bpe_delimiter,
                              number_token=None,
                              name_token=None):
    """
    Decoding only a specific set of sentences indicated by inference_indices
    :param output_infer:
    :param output_infer_summary_prefix:
    :param inference_indices: A list of sentence indices
    :param eos: the eos token
    :param bpe_delimiter: delimiter used for byte-pair entries
    :return:
    """
    utils.print_out("  decoding to output %s , num sents %d." %
                    (output_infer_file, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(tf.gfile.GFile(output_infer_file, 'wb')) as f:
        f.write("")  # Write empty string to ensure that the file is created
        # Get the outputs
        outputs, infer_summary = model.decode(sess)

        # Iterate over the sentences we want to process. Use the index to process sentences and the
        # decode_id to create logs
        for sentence_id, decode_id in enumerate(inference_indices):
            # Get the response
            response = chatbot_utils.postprocess_output(outputs, sentence_id=sentence_id, eos=eos,
                                                        bpe_delimiter=bpe_delimiter, number_token=number_token,
                                                        name_token=name_token)
            # TODO: add inference_summary if deciding to use attention

            # Write the response to file
            f.write("%s\n" % response)
            utils.print_out("%s\n" % response)
    utils.print_time("  done", start_time)
Example #8
def compute_perplexity(model, sess, name, eval_handle):
    """Compute perplexity of the output of the model based on loss function."""
    def aggregate_all_summaries(original, updates):
        for key in updates:
            if key not in original:
                original[key] = 0.0
            original[key] += updates[key]
        return original

    total_loss = 0
    total_predict_count = 0
    start_time = time.time()
    aggregated_summaries = {}
    batch_processed = 0
    while True:
        try:
            loss, all_summaries, predict_count, batch_size = model.eval(
                sess, eval_handle)
            total_loss += loss * batch_size
            batch_processed += 1
            total_predict_count += predict_count
            aggregated_summaries = aggregate_all_summaries(
                aggregated_summaries, all_summaries)
        except tf.errors.OutOfRangeError:
            break

    perplexity = utils.safe_exp(total_loss / total_predict_count)
    for key in aggregated_summaries:
        if key not in set([
                "eval_dialogue_loss1", "eval_dialogue_loss2",
                "eval_action_loss3"
        ]):
            aggregated_summaries[key] /= batch_processed
    utils.print_time("  eval %s: perplexity %.2f" % (name, perplexity),
                     start_time)
    return perplexity, aggregated_summaries
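For clarity, aggregate_all_summaries above simply sums each summary value across evaluation batches; after the loop, every key except the listed loss totals is divided by batch_processed to turn the sum into a per-batch mean. With made-up key names and values:

aggregated = {}
aggregated = aggregate_all_summaries(aggregated, {"eval_accuracy": 0.8, "eval_dialogue_loss1": 12.0})
aggregated = aggregate_all_summaries(aggregated, {"eval_accuracy": 0.6, "eval_dialogue_loss1": 10.0})
# aggregated == {"eval_accuracy": 1.4, "eval_dialogue_loss1": 22.0}
# After the loop, eval_accuracy is divided by batch_processed (2) -> 0.7,
# while eval_dialogue_loss1 is left as a running total.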
Example #9
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    # Create model
    model_creator = get_model_creator(hparams)
    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data,
                  avg_ckpts)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info, get_best_results(hparams),
                            log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            add_info_summaries(summary_writer, global_step, info)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = (run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts))
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step,
                            info, result_summary, log_f)
            summary_writer.close()

    return final_eval_metrics, global_step
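In this training loop the evaluation cadence is derived from steps_per_stats: internal eval runs every 10x that interval and, unless steps_per_external_eval is set in hparams, external eval runs every 5x the eval interval. A small illustration of the resulting schedule, with an assumed value for steps_per_stats:

steps_per_stats = 100
steps_per_eval = 10 * steps_per_stats           # 1000: checkpoint + internal eval (perplexity)
steps_per_external_eval = 5 * steps_per_eval    # 5000: sample decode + external (BLEU/ROUGE) eval
print(steps_per_stats, steps_per_eval, steps_per_external_eval)   # 100 1000 5000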
Example #10
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "summarized.ckpt"),
      global_step=global_step)

  result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
      model_dir, infer_model, infer_sess,
      eval_model, eval_sess, hps,
      summary_writer, sample_src_data, sample_tgt_data)
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()),
      log_f)
  utils.print_time("# Done training!", start_train_time)

  summary_writer.close()

  utils.print_out("# Start evaluating saved best models.")
  for metric in hps.metrics:
    best_model_dir = getattr(hps, "best_" + metric + "_dir")
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _, _, _, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hps,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out("# Best %s, step %d "
                    "step-time %.2f wps %.2fK, %s, %s" %
                    (metric, best_global_step, avg_step_time, speed,
                     result_summary, time.ctime()), log_f)
Example #11
def train(hparams, identity, scope=None, target_session=""):
  """main loop to train the dialogue model. identity is used."""
  out_dir = hparams.out_dir
  steps_per_stats = hparams.steps_per_stats
  steps_per_internal_eval = 3 * steps_per_stats

  model_creator = diag_model.Model

  train_model = model_helper.create_train_model(model_creator, hparams, scope)

  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, identity+"log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  avg_step_time = 0.0

  # load TensorFlow session and model
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)

  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)

  train_handle = train_sess.run(train_model.train_iterator.string_handle())

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # initialize summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)

  last_stats_step = global_step
  last_eval_step = global_step

  # initialize training stats.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()),
      log_f)

  # initialize iterators
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.train_iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  # main training loop
  while global_step < hparams.num_train_steps:
    start_time = time.time()
    try:  #  run a step
      step_result = loaded_train_model.train(train_sess, train_handle)
      (_, step_loss, all_summaries, step_predict_count, step_summary,
       global_step, step_word_count, batch_size, _, _, words1, words2, mask1,
       mask2) = step_result
      hparams.epoch_step += 1

    except tf.errors.OutOfRangeError:  # finished an epoch
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.train_iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    for key in all_summaries:
      utils.add_summary(summary_writer, global_step, key, all_summaries[key])

    # update statistics
    step_time += (time.time() - start_time)

    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    if global_step - last_stats_step >= steps_per_stats:
      # print statistics for the previous epoch and save the model.
      last_stats_step = global_step

      avg_step_time = step_time / steps_per_stats
      utils.add_summary(summary_writer, global_step, "step_time", avg_step_time)
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      if math.isnan(train_ppl):
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

      # save the model
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "dialogue.ckpt"),
          global_step=global_step)

      # print the dialogue if in debug mode
      if hparams.debug:
        utils.print_current_dialogue(words1, words2, mask1, mask2)

    #  write out internal evaluation
    if global_step - last_eval_step >= steps_per_internal_eval:
      last_eval_step = global_step

      utils.print_out("# Internal Evaluation. global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

  # finished training
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "dialogue.ckpt"),
      global_step=global_step)
  result_summary = ""
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()),
      log_f)
  utils.print_time("# Done training!", start_train_time)
  utils.print_out("# Start evaluating saved best models.")
  summary_writer.close()
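The statistics printed by this loop follow the same pattern as the other training examples: train_ppl is the exponentiated average per-token loss over the stats window, and speed ("wps ...K") is thousands of words processed per second. A hedged numeric illustration (the values are invented; the real ones come from the step results above):

import math

steps_per_stats = 100
step_time = 50.0                      # seconds accumulated since the last stats print
checkpoint_loss = 300000.0            # sum of step_loss * batch_size
checkpoint_predict_count = 100000.0   # predicted target tokens in the window
checkpoint_total_count = 250000.0     # words processed in the window

avg_step_time = step_time / steps_per_stats                        # 0.5 s per step
train_ppl = math.exp(checkpoint_loss / checkpoint_predict_count)   # exp(3.0) ~ 20.1
speed = checkpoint_total_count / (1000 * step_time)                # 5.0 -> "wps 5.00K"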
Example #12
def decode_and_evaluate(run,
                        iterations,
                        name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        subword_option,
                        beam_width,
                        tgt_eos,
                        num_translations_per_input=1,
                        decode=True,
                        infer_mode="greedy"):
    """Decode a test set and compute a score according to the evaluation task."""
    # Decode
    if decode:
        utils.print_out("  decoding to output %s" % trans_file)

        num_sentences = 0
        with codecs.getwriter("utf-8")(tf.gfile.GFile(trans_file,
                                                      mode="wb")) as trans_f:
            trans_f.write("")  # Write empty string to ensure file is created.

            if infer_mode == "greedy":
                num_translations_per_input = 1
            elif infer_mode == "beam_search":
                num_translations_per_input = min(num_translations_per_input,
                                                 beam_width)

            print("  infer_mode %s, beam_width %g, num translations per input %d. " \
                  % (infer_mode, beam_width, num_translations_per_input))
            print("  total iterations count %d." % iterations)

            # prediction time is the time for the model prediction only
            # overall time is the time for data pre-processing and data post-processing
            prediction_times = list()
            overall_start = time.time()

            n = 0
            while n < iterations:
                n += 1
                while True:
                    try:
                        start = time.time()
                        nmt_outputs, _ = model.decode(sess)
                        prediction_times.append(time.time() - start)
                        if infer_mode != "beam_search":
                            nmt_outputs = np.expand_dims(nmt_outputs, 0)

                        batch_size = nmt_outputs.shape[1]

                        num_sentences += batch_size
                        for sent_id in range(batch_size):
                            for beam_id in range(num_translations_per_input):
                                translation = get_translation(
                                    nmt_outputs[beam_id],
                                    sent_id,
                                    tgt_eos=tgt_eos,
                                    subword_option=subword_option)
                                if run == 'accuracy':
                                    trans_f.write(
                                        (translation + b"\n").decode("utf-8"))

                    except tf.errors.OutOfRangeError:
                        utils.print_time(
                            "  done, num sentences %d, num translations per input %d"
                            % (num_sentences, num_translations_per_input),
                            overall_start)
                        break

        overall_time = (time.time() - overall_start)
        if run == 'performance':
            print("\nAverage Prediction Latency: {:.5f} sec per batch.".format(
                sum(prediction_times) / float(len(prediction_times))))
            print("Overall Latency: {:.5f} sec for the entire test "
                  "dataset.".format(overall_time / float(iterations)))
            print("Overall Throughput : {:.3f} sentences per sec.".format(
                num_sentences / float(overall_time)))

    # Evaluation
    evaluation_scores = {}
    if ref_file and tf.gfile.Exists(trans_file):
        for metric in metrics:
            score = evaluation_utils.evaluate(ref_file,
                                              trans_file,
                                              metric,
                                              subword_option=subword_option)
            evaluation_scores[metric] = score
            utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores
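Example #12 reports three timing metrics when run == 'performance'; the arithmetic is straightforward. With assumed numbers:

prediction_times = [0.20, 0.22, 0.18]   # per-batch model.decode() times in seconds (assumed)
overall_time = 30.0                     # wall time for decoding plus pre/post-processing
iterations = 1
num_sentences = 600

avg_latency = sum(prediction_times) / len(prediction_times)   # 0.20 s per batch
overall_latency = overall_time / float(iterations)            # 30.0 s per pass over the test set
throughput = num_sentences / float(overall_time)              # 20.0 sentences per second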
Example #13
def train(hps, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hps.log_device_placement
    out_dir = hps.out_dir
    num_train_steps = hps.num_train_steps
    steps_per_stats = hps.steps_per_stats
    steps_per_external_eval = hps.steps_per_external_eval
    steps_per_eval = 100 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hps.attention_architecture == "baseline":
        model_creator = AttentionModel
    else:
        model_creator = AttentionHistoryModel

    train_model = model_helper.create_train_model(model_creator, hps, scope)
    eval_model = model_helper.create_eval_model(model_creator, hps, scope)
    infer_model = model_helper.create_infer_model(model_creator, hps, scope)

    # Preload data for sample decoding.

    article_filenames = []
    abstract_filenames = []
    art_dir = hps.data_dir + '/article'
    abs_dir = hps.data_dir + '/abstract'
    for file in os.listdir(art_dir):
        if file.startswith(hps.dev_prefix):
            article_filenames.append(art_dir + "/" + file)
    for file in os.listdir(abs_dir):
        if file.startswith(hps.dev_prefix):
            abstract_filenames.append(abs_dir + "/" + file)
    # if random_decode:
    #     """if this is a random sampling process during training"""
    decode_id = random.randint(0, len(article_filenames) - 1)
    single_article_file = article_filenames[decode_id]
    single_abstract_file = abstract_filenames[decode_id]

    dev_src_file = single_article_file
    dev_tgt_file = single_abstract_file
    sample_src_data = inference_base_model.load_data(dev_src_file)
    sample_tgt_data = inference_base_model.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hps.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    # run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess, hps,
    #     summary_writer,sample_src_data,sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hps.batch_size * hps.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})
    epoch_step = 0
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = update_stats(stats, summary_writer, start_time,
                                   step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = check_stats(stats, global_step, steps_per_stats, hps,
                                      log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hps,
                                                  summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "summarized.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hps,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hps.metrics:
        best_model_dir = getattr(hps, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hps, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)
        summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
Example #14
    def infer(self, num_print_per_batch=0):
        model_dir = self.model_dir
        out_dir = self.out_dir
        dev_src_file = self.dev_src_file
        dev_tgt_file = self.dev_tgt_file
        infer_batch_size = self.hparams.infer_batch_size
        beam_width = self.hparams.beam_width
        infer_model = self.infer_model
        infer_sess = self.infer_sess

        infer_output_file = os.path.join(out_dir, 'infer_output')

        start_time = time.time()
        print('# Decoding to %s' % infer_output_file)

        # Load infer model
        with infer_model.graph.as_default():
            loaded_infer_model, global_step = model_helper.create_or_load_model(
                infer_model.model, model_dir, infer_sess, "infer")

        with open(dev_src_file, encoding='utf-8') as in_src_file, \
                open(dev_tgt_file, encoding='utf-8') as in_tgt_file, \
                open(infer_output_file, mode='w', encoding='utf-8') as out_file:
            infer_src_data = in_src_file.readlines()
            infer_tgt_data = in_tgt_file.readlines()

            iterator_feed_dict = {
                infer_model.src_data_placeholder: infer_src_data,
                infer_model.batch_size_placeholder: infer_batch_size
            }
            infer_sess.run(
                infer_model.iterator.initializer,
                feed_dict=iterator_feed_dict)

            num_sentences = 0
            while True:
                try:
                    # The shape of sample_words is [batch_size, time] or
                    # [beam_width, batch_size, time] when using beam search.
                    sample_words = loaded_infer_model.decode(infer_sess)

                    if beam_width == 0:
                        sample_words = np.expand_dims(sample_words, 0)

                    batch_size = sample_words.shape[1]

                    for sent_id in range(batch_size):
                        beam_id = random.randint(0, beam_width - 1) if beam_width > 0 else 0
                        response = self._get_response(sample_words[beam_id][sent_id])
                        out_file.write(response + '\n')

                        if sent_id < num_print_per_batch:
                            sent_id += num_sentences
                            print("  sentence %d" % sent_id)
                            print("  src: %s" % infer_src_data[sent_id], end='')
                            print("  ref: %s" % infer_tgt_data[sent_id], end='')
                            print("  bot: %s" % response)

                    num_sentences += batch_size
                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  done, num sentences %d, beam width %d" %
                        (num_sentences, beam_width), start_time)
                    break
Example #15
def train(hparams, scope=None, target_session=''):
    """Train the chatbot"""
    # Initialize some local hyperparameters
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator
        get_iterator = iterator_utils.get_iterator
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel
        # Parse some of the arguments now
        def curry_get_infer_iterator(dataset, vocab_table, batch_size, src_reverse,
                       eos, src_max_len):
            return end2end_iterator_utils.get_infer_iterator(dataset, vocab_table, batch_size, src_reverse, eos,
                                                      src_max_len=src_max_len, eou=hparams.eou,
                                                      dialogue_max_len=hparams.dialogue_max_len)
        get_infer_iterator = curry_get_infer_iterator

        def curry_get_iterator(src_dataset,
                 tgt_dataset,
                 vocab_table,
                 batch_size,
                 sos,
                 eos,
                 src_reverse,
                 random_seed,
                 num_buckets,
                 src_max_len=None,
                 tgt_max_len=None,
                 num_threads=4,
                 output_buffer_size=None,
                 skip_count=None):
            return end2end_iterator_utils.get_iterator(src_dataset, tgt_dataset, vocab_table, batch_size, sos, eos,
                                                eou=hparams.eou, src_reverse=src_reverse, random_seed=random_seed,
                                                num_dialogue_buckets=num_buckets, src_max_len=src_max_len,
                                                tgt_max_len=tgt_max_len, num_threads=num_threads,
                                                output_buffer_size=output_buffer_size, skip_count=skip_count)

        get_iterator = curry_get_iterator
    else:
        raise ValueError("Unkown architecture", hparams.architecture)

    # Create three models which share parameters through the use of checkpoints
    train_model = create_train_model(model_creator, get_iterator, hparams, scope)
    eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
    infer_model = inference.create_infer_model(model_creator, get_infer_iterator, hparams, scope)
    # ToDo: adapt for architectures
    # Preload the data to use for sample decoding

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # Create the configurations for the sessions
    config_proto = utils.get_config_proto(log_device_placement=log_device_placement)
    # Create three sessions, one for each model
    train_sess = tf.Session(target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto, graph=infer_model.graph)

    # Load the train model from checkpoint or create a new one
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(train_model.model, model_dir,
                                                                            train_sess, name="train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)
    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    # Initialize the hyperparameters for the loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # epoch_step records where we were within an epoch. Used to skip trained on examples
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    # Initialize the training iterator
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    # Train until we reach num_steps.
    while global_step < num_train_steps:
        # Run a step
        start_step_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,  # The _ is the output of the update op
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # Decode and print a random sentence
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Perform external evaluation to save checkpoints if this is the best for some metric
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)
            # Reinitialize the iterator from the beginning
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_step_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # The model has screwed up
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            # Perform evaluation. Start by reassigning the last_eval_step variable to the current step
            last_eval_step = global_step
            # Print the progress and add summary
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run internal evaluation, and update the ppl variables. The data iterator is instantiated in the method.
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            # Run the external evaluation
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run external evaluation, updating metric scores in the meanwhile. The unneeded output is the global step.
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)

    # Done training. Save the model
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "chatbot.ckpt"),
        global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out("# Best %s, step %d "
                        "step-time %.2f wps %.2fK, %s, %s" %
                        (metric, best_global_step, avg_step_time, speed,
                         result_summary, time.ctime()), log_f)

    summary_writer.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
Example #16
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 50 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 2 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # run_sample_decode(infer_model, infer_sess,
            #                   model_dir, hparams, summary_writer, sample_src_data,
            #                   sample_tgt_data)
            # dev_scores, test_scores, _ = run_external_eval(
            #     infer_model, infer_sess, model_dir,
            #     hparams, summary_writer)
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = update_stats(stats, summary_writer, start_time,
                                   step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = check_stats(stats, global_step, steps_per_stats,
                                      hparams, log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)

            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)
        summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
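When this variant resumes training in the middle of an epoch, it re-initializes the iterator and skips the examples that were already consumed, using skip_count = batch_size * epoch_step. A short worked sketch of that arithmetic with illustrative numbers:

# Resuming after 40 completed steps of the current epoch with batch size 128:
batch_size = 128
epoch_step = 40
skip_count = batch_size * epoch_step  # 5120 source/target pairs are skipped
# At the start of a fresh epoch (epoch_step == 0) nothing is skipped.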
Example #17
def train(hparams, scope=None, target_session=""):
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:
        if hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_util.create_train_model(model_creator, hparams, scope)
    eval_model = model_util.create_eval_model(model_creator, hparams, scope)
    infer_model = model_util.create_infer_model(model_creator, hparams, scope)

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_util.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data,
                  avg_ckpts)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)

    while global_step < num_train_steps:
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            if is_overflow:
                break

            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_perplexity",
                              info["train_perplexity"])

            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    result_summary, _, final_eval_metrics = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step,
                            info, result_summary, log_f)
            summary_writer.close()

    return final_eval_metrics, global_step
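Printing statistics, internal evaluation with checkpointing, and external evaluation are all triggered by the same test: how far global_step has advanced past the step at which the action last ran. A stand-alone sketch of that trigger and of the default cadence in this variant (the helper name is ours, not part of the original code):

def due(global_step, last_step, interval):
    # True once at least `interval` training steps have elapsed since the
    # action was last performed at `last_step`.
    return global_step - last_step >= interval

# With steps_per_stats = 100, this variant evaluates every 10 * 100 = 1000
# steps and, unless steps_per_external_eval is set explicitly, runs external
# evaluation every 5 * 1000 = 5000 steps.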
Example #18
def train(hparams, scope=None, target_session="", single_cell_fn=None):
    """Train a translation model."""

    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats

    if hparams.eval_on_fly:
        steps_per_external_eval = hparams.steps_per_external_eval
        steps_per_eval = 10 * steps_per_stats

        if not steps_per_external_eval:
            steps_per_external_eval = 2 * steps_per_eval
    else:
        steps_per_snapshot = hparams.snapshot_interval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = create_train_model(model_creator, hparams, scope,
                                     single_cell_fn)

    if hparams.eval_on_fly:
        eval_model = create_eval_model(model_creator, hparams, scope,
                                       single_cell_fn)
        infer_model = inference.create_infer_model(model_creator, hparams,
                                                   scope, single_cell_fn)

        # Preload data for sample decoding.
        dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
        dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
        sample_src_data = inference.load_data(dev_src_file)
        sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)

    if hparams.eval_on_fly:
        eval_sess = tf.Session(target=target_session,
                               config=config_proto,
                               graph=eval_model.graph)
        infer_sess = tf.Session(target=target_session,
                                config=config_proto,
                                graph=infer_model.graph)

    with train_model.graph.as_default():
        model_helper.initialize_cnn(train_model.model, train_sess)
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    if hparams.eval_on_fly:
        run_full_eval(model_dir, infer_model, infer_sess, eval_model,
                      eval_sess, hparams, summary_writer, sample_src_data,
                      sample_tgt_data)

    last_stats_step = global_step

    if hparams.eval_on_fly:
        last_eval_step = global_step
        last_external_eval_step = global_step
    else:
        last_snapshot_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            if hparams.eval_on_fly:
                run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer, sample_src_data,
                                  sample_tgt_data)
                dev_scores, test_scores, _ = run_external_eval(
                    infer_model, infer_sess, model_dir, hparams,
                    summary_writer)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss /
                                       checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g step-time %.2fs wps %.2fK ppl %.2f %s"
                %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        # Snapshot-only mode: periodically save a checkpoint without running evaluation.
        if (not hparams.eval_on_fly) and (global_step - last_snapshot_step >=
                                          steps_per_snapshot):
            last_snapshot_step = global_step
            utils.print_out("# Cihan: Saving Snapshot, global step %d" %
                            global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

        if hparams.eval_on_fly and (global_step - last_eval_step >=
                                    steps_per_eval):
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)

        if hparams.eval_on_fly and (global_step - last_external_eval_step >=
                                    steps_per_external_eval):
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    if hparams.eval_on_fly:
        result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
            model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)

        utils.print_out(
            "# Final, step %d lr %g step-time %.2f wps %.2fK ppl %.2f, %s, %s"
            % (global_step,
               loaded_train_model.learning_rate.eval(session=train_sess),
               avg_step_time, speed, train_ppl, result_summary, time.ctime()),
            log_f)
        utils.print_time("# Done training!", start_train_time)

        utils.print_out("# Start evaluating saved best models.")
        for metric in hparams.metrics:
            best_model_dir = getattr(hparams, "best_" + metric + "_dir")
            result_summary, best_global_step, _, _, _, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            utils.print_out(
                "# Best %s, step %d step-time %.2f wps %.2fK, %s, %s" %
                (metric, best_global_step, avg_step_time, speed,
                 result_summary, time.ctime()), log_f)

    summary_writer.close()

    if hparams.eval_on_fly:
        return dev_scores, test_scores, dev_ppl, test_ppl, global_step
    else:
        return global_step
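Between two statistics printouts this variant accumulates wall time, loss, the number of predicted tokens, and the number of target words, then derives step time, perplexity, and throughput from the window totals. A self-contained sketch of those formulas; utils.safe_exp is not shown in this excerpt, so the overflow cap below only mirrors its intent:

import math

def window_summary(step_time, checkpoint_loss, checkpoint_predict_count,
                   checkpoint_total_count, steps_per_stats):
    # Average seconds per training step over the reporting window.
    avg_step_time = step_time / steps_per_stats
    # Perplexity = exp(total loss / predicted tokens), capped to avoid overflow.
    exponent = checkpoint_loss / checkpoint_predict_count
    train_ppl = math.exp(exponent) if exponent < 300 else float("inf")
    # Throughput in thousands of target words per second.
    speed = checkpoint_total_count / (1000 * step_time)
    return avg_step_time, train_ppl, speed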
Example #19
def train():
	"""Train a translation model."""
	create_new_model = params['create_new_model']
	out_dir = params['out_dir']
	model_creator = nmt_model.Model # Create model graph
	summary_name = "train_log"

	# Set up the session and initialize the input data iterators
	src_file = params['src_data_file']
	tgt_file = params['tgt_data_file']
	dev_src_file = params['dev_src_file']
	dev_tgt_file = params['dev_tgt_file']
	test_src_file = params['test_src_file']
	test_tgt_file = params['test_tgt_file']


	char_vocab_file = params['enc_char_map_path']
	src_vocab_file = params['src_vocab_file']
	tgt_vocab_file = params['tgt_vocab_file']
	if src_vocab_file == '' or tgt_vocab_file == '':
		raise ValueError("src_vocab_file or tgt_vocab_file not given in params.")

	graph = tf.Graph()

	# Log and output files
	log_file = os.path.join(out_dir, "log_%d" % time.time())
	log_f = tf.gfile.GFile(log_file, mode="a")
	utils.print_out("# log_file=%s" % log_file, log_f)


	# Model run params
	num_epochs = params['num_epochs']
	batch_size = params['batch_size']
	steps_per_stats = params['steps_per_stats']

	utils.print_out("# Epochs=%s, Batch Size=%s, Steps_per_Stats=%s" % (num_epochs, batch_size, steps_per_stats), None)

	with graph.as_default():
		src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file, params['share_vocab'])
		char_vocab_table = vocab_utils.get_char_table(char_vocab_file)
		reverse_target_table = lookup_ops.index_to_string_table_from_file(tgt_vocab_file, default_value=params['unk'])

		src_dataset = tf.data.TextLineDataset(src_file)
		tgt_dataset = tf.data.TextLineDataset(tgt_file)

		batched_iter = iterator_utils.get_iterator(src_dataset,
											   tgt_dataset,
											   char_vocab_table,
											   src_vocab_table,
											   tgt_vocab_table,
											   batch_size=batch_size,
											   sos=params['sos'],
											   eos=params['eos'],
											   char_pad = params['char_pad'],
											   num_buckets=params['num_buckets'],
											   num_epochs = params['num_epochs'],
											   src_max_len=params['src_max_len'],
											   tgt_max_len=params['tgt_max_len'],
											   src_char_max_len = params['char_max_len']
											   )

		# Summary writer
		summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name), graph)


		# Preload validation data for decoding.
		dev_src_dataset = tf.data.TextLineDataset(dev_src_file)
		dev_tgt_dataset = tf.data.TextLineDataset(dev_tgt_file)
		dev_batched_iterator = iterator_utils.get_iterator(dev_src_dataset,
														   dev_tgt_dataset,
														   char_vocab_table,
														   src_vocab_table,
														   tgt_vocab_table,
														   batch_size=batch_size,
														   sos=params['sos'],
														   eos=params['eos'],
														   char_pad = params['char_pad'],
														   num_buckets=params['num_buckets'],
														   num_epochs = params['num_epochs'],
														   src_max_len=params['src_max_len'],
														   tgt_max_len=params['tgt_max_len'],
														   src_char_max_len = params['char_max_len']
														   )

		# Preload test data for decoding.
		test_src_dataset = tf.data.TextLineDataset(test_src_file)
		test_tgt_dataset = tf.data.TextLineDataset(test_tgt_file)
		test_batched_iterator = iterator_utils.get_iterator(test_src_dataset,
														   test_tgt_dataset,
														   char_vocab_table,
														   src_vocab_table,
														   tgt_vocab_table,
														   batch_size=batch_size,
														   sos=params['sos'],
														   eos=params['eos'],
														   char_pad = params['char_pad'],
														   num_buckets=params['num_buckets'],
														   num_epochs = params['num_epochs'],
														   src_max_len=params['src_max_len'],
														   tgt_max_len=params['tgt_max_len'],
														   src_char_max_len = params['char_max_len']
														   )

		config_proto = utils.get_config_proto(log_device_placement=params['log_device_placement'])
		sess = tf.Session(config=config_proto)


		with sess.as_default():
			

			train_model = model_creator(mode = params['mode'],
										train_iterator = batched_iter,
										val_iterator = dev_batched_iterator,
										char_vocab_table = char_vocab_table,
										source_vocab_table=src_vocab_table,
										target_vocab_table=tgt_vocab_table,
										reverse_target_table = reverse_target_table)

			loaded_train_model, global_step = create_or_load_model(
				train_model, params['out_dir'], session=sess, name="train",
				log_f=log_f, create=create_new_model)

			sess.run([batched_iter.initializer, dev_batched_iterator.initializer, test_batched_iterator.initializer])

			start_train_time = time.time()
			utils.print_out("# Start step %d, lr %g, %s" % (global_step, loaded_train_model.learning_rate.eval(session=sess), time.ctime()), log_f)
			
			# Reset statistics
			stats = init_stats()

			steps_per_epoch = int(np.ceil(utils.get_file_row_size(src_file) / batch_size))
			utils.print_out("Total steps per epoch: %d" % steps_per_epoch)

			def train_step(model, sess):	
				return model.train(sess)
			def dev_step(model, sess):
				total_steps = int(np.ceil(utils.get_file_row_size(dev_src_file) / batch_size ))
				total_dev_loss = 0.0
				total_accuracy = 0.0
				for _ in range(total_steps):
					dev_result_step = model.dev(sess)
					dev_softmax_scores, dev_loss, tgt_output_ids,_,_,_,_ = dev_result_step
					total_dev_loss += dev_loss * params['batch_size']
					total_accuracy += evaluation_utils._accuracy(dev_softmax_scores, tgt_output_ids,  None, None)
				return (total_dev_loss/total_steps, total_accuracy/total_steps)


			for epoch_step in range(num_epochs):
				for curr_step in range(steps_per_epoch):
					start_time = time.time()
					step_result = train_step(loaded_train_model, sess)
					global_step = update_stats(stats, summary_writer, start_time, step_result)

					# Logging step
					if curr_step % params['steps_per_stats'] == 0:
						check_stats(stats, global_step, steps_per_stats, log_f)



					# Evaluation
					if curr_step % params['steps_per_devRun'] == 0:
						dev_step_loss, dev_step_acc = dev_step(loaded_train_model, sess)
						utils.print_out("Dev step loss, accuracy: %f, %f" % (dev_step_loss, dev_step_acc), log_f)

				utils.print_out("# Finished an epoch, epoch completed %d" % epoch_step)
				loaded_train_model.saver.save(sess,  os.path.join(out_dir, "translate.ckpt"), global_step=global_step)
				dev_step_loss = dev_step(loaded_train_model, sess)


			utils.print_time("# Done training!", start_train_time)
			summary_writer.close()
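Unlike the previous variants, this one drives training by epochs and derives the number of steps per epoch from the size of the training file rather than from a global step budget. A short worked sketch of that arithmetic with illustrative file sizes:

import numpy as np

num_examples = 100000   # lines in src_data_file (illustrative)
batch_size = 64
steps_per_epoch = int(np.ceil(num_examples / batch_size))  # 1563 steps
# The last batch of each epoch may therefore contain fewer than batch_size pairs.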