Code Example #1
File: inference.py Project: herbertchen1/MANNs4NMT
def inference(ckpt,
              inference_input_file,
              inference_output_file,
              hparams,
              num_workers=1,
              jobid=0,
              scope=None):
    """Perform translation."""
    if hparams.inference_indices:
        assert num_workers == 1

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    if num_workers == 1:
        single_worker_inference(infer_model, ckpt, inference_input_file,
                                inference_output_file, hparams)
    else:
        multi_worker_inference(infer_model,
                               ckpt,
                               inference_input_file,
                               inference_output_file,
                               hparams,
                               num_workers=num_workers,
                               jobid=jobid)
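
A minimal driver sketch for this entry point, assuming a model directory laid out as in Code Example #14 below; the paths here are hypothetical:

out_dir = "/tmp/nmt_attention_model"        # hypothetical model directory
hparams = utils.load_hparams(out_dir)       # load saved hyperparameters, as in Code Example #14
ckpt = tf.train.latest_checkpoint(out_dir)  # newest checkpoint in the directory
inference(ckpt,
          "/tmp/infer_input.txt",           # one source sentence per line
          "/tmp/infer_output.txt",          # translations are written here
          hparams,
          num_workers=1)                    # >1 shards the input across workers by jobid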
Code Example #2
File: infer.py Project: zhangshuai881020/chinese_nlp
def infer(ckpt, inference_input_file, inference_output_file, hparams):
    """
    Perform translation.
    """
    model_creator = gnmt_model.GNMTModel
    infer_model = model_helper.create_infer_model(model_creator, hparams)

    # Read data
    infer_data = utils.load_data(inference_input_file)

    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    with tf.Session(
        graph=infer_model.graph, config=config_proto) as sess:

        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        sess.run(
            infer_model.iterator.initializer,
            feed_dict={
                infer_model.src_placeholder: infer_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            })
        # Decode
        utils.log("Start decoding")
        loaded_infer_model.decode_and_evaluate(
            "infer",
            sess,
            inference_output_file,
            ref_file=None,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)
Code Example #3
def evaluate(hparams, ckpt):
    if hparams.model_architecture == "rnn-model":
        model_creator = model.RNN
    else:
        raise ValueError("Unknown model architecture. Only rnn-model is supported so far.")

    if hparams.val_target_path:
        eval_model = model_helper.create_eval_model(model_creator, hparams, tf.contrib.learn.ModeKeys.EVAL)
        eval_sess = tf.Session(config=utils.get_config_proto(), graph=eval_model.graph)
        with eval_model.graph.as_default():
            loaded_eval_model = model_helper.load_model(eval_model.model, eval_sess, "evaluation", ckpt)
        iterator_feed_dict = {
            eval_model.input_file_placeholder: hparams.eval_input_path,
            eval_model.output_file_placeholder: hparams.eval_target_path
        }
        eval_loss = eval(loaded_eval_model, eval_sess, eval_model.iterator, iterator_feed_dict)
        print("Eval loss: %.3f"%eval_loss)
    print("Starting predictions:")

    prediction_model = model_helper.create_infer_model(model_creator, hparams, tf.contrib.learn.ModeKeys.INFER)
    prediction_sess = tf.Session(config=utils.get_config_proto(), graph=prediction_model.graph)
    with prediction_model.graph.as_default():
        loaded_prediction_model = model_helper.load_model(prediction_model.model, prediction_sess, "prediction", ckpt)
        iterator_feed_dict = {
            prediction_model.input_file_placeholder: hparams.val_input_path,
        }
    predictions = predict(loaded_prediction_model, prediction_sess, prediction_model.iterator, iterator_feed_dict)
    np.savetxt(os.path.join(hparams.eval_output_folder, "classes.txt"), predictions["classes"])
    np.savetxt(os.path.join(hparams.eval_output_folder, "probabilities.txt"), predictions["probabilities"])
Code Example #4
def infer(hparams, ckpt_dir, scope=None, target_session=""):
    output_dir = os.path.join(hparams.base_dir, "infer")
    if not misc.check_file_existence(output_dir):
        tf.gfile.MakeDirs(output_dir)
    model_creator = get_model_creator(hparams.model_type)
    infer_model = helper.create_infer_model(model_creator, hparams, scope)
    infer_sess, loaded_infer_model, global_step = start_sess_and_load_model(
        infer_model, ckpt_dir)
    tf.logging.info("Restore model from global step %d" % global_step)
    # Summary
    summary_name = "infer_summary"
    summary_path = os.path.join(output_dir, summary_name)
    if not tf.gfile.Exists(summary_path):
        tf.gfile.MakeDirs(summary_path)
    summary_writer = tf.summary.FileWriter(
        summary_path, infer_model.graph)
    infer_results = []
    tf.logging.info("Ready to infer")
    start_time = time.time()
    step = 0
    while True:
        try:
            tf.logging.info("Start infer step:%d" % step)
            results = loaded_infer_model.infer(infer_sess)
            summary_writer.add_summary(results.summary, global_step)
            infer_results.append(results.detected_images)
            step += 1
        except tf.errors.OutOfRangeError:
            tf.logging.info("Finish infer <time:%d>" % (start_time-time.time()))
            break
    for result in infer_results:
        plt.imshow(result)  # plt.show() takes no image argument; render each image first
        plt.show()
Code Example #5
def inference(ckpt_path,
              inference_input_file,
              inference_output_file,
              hparams,
              num_workers=1,
              jobid=0,
              scope=None):
    """Perform translation."""
    if hparams.inference_indices:
        assert num_workers == 1

    model_creator = get_model_creator(hparams)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)
    sess, loaded_infer_model = start_sess_and_load_model(
        infer_model, ckpt_path)
    translation = ''
    if num_workers == 1:
        translation = single_worker_inference(sess, infer_model,
                                              loaded_infer_model,
                                              inference_input_file,
                                              inference_output_file, hparams)
    else:
        multi_worker_inference(sess,
                               infer_model,
                               loaded_infer_model,
                               inference_input_file,
                               inference_output_file,
                               hparams,
                               num_workers=num_workers,
                               jobid=jobid)
    sess.close()
    return translation
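
Code Examples #4, #5, and #13 call start_sess_and_load_model without defining it. Below is a plausible reconstruction from the session-creation and checkpoint-loading pattern of Code Examples #2 and #15; the exact signature is an assumption, and Code Example #4's variant also returns the restored global step.

def start_sess_and_load_model(infer_model, ckpt_path):
    """Assumed helper: open a session on the inference graph and restore ckpt_path."""
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    sess = tf.Session(graph=infer_model.graph, config=config_proto)
    with infer_model.graph.as_default():
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt_path, sess, "infer")
    return sess, loaded_infer_model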
Code Example #6
def inference(ckpt,
              inference_input_file,
              inference_output_file,
              hparams,
              num_workers=1,
              jobid=0,
              scope=None):
    """Perform translation."""
    if hparams.inference_indices:
        assert num_workers == 1

    model_creator = AttentionModel

    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    if num_workers == 1:
        single_worker_inference(infer_model, ckpt, inference_input_file,
                                inference_output_file, hparams)
    else:
        multi_worker_inference(infer_model,
                               ckpt,
                               inference_input_file,
                               inference_output_file,
                               hparams,
                               num_workers=num_workers,
                               jobid=jobid)
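
single_worker_inference is likewise called throughout but never shown. A sketch reconstructed from the inline flow of Code Example #2 (batch decoding with decode_and_evaluate); each project's real helper may differ:

def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
    """Assumed shape: decode one input file in a single session, as in Code Example #2."""
    infer_data = utils.load_data(inference_input_file)  # one sentence per line
    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    with tf.Session(graph=infer_model.graph, config=config_proto) as sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder: hparams.infer_batch_size
                 })
        loaded_infer_model.decode_and_evaluate(
            "infer", sess, inference_output_file,
            ref_file=None,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)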
Code Example #7
File: train.py Project: ml-lab/Pun-GAN
def run_sample_decode_pungan_prepare(hparams, scope=None, target_session=""):
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    if not hparams.attention:  # choose this model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)
    wsd_src_file = "%s" % (hparams.sample_prefix)
    wsd_src_data = inference.load_data(wsd_src_file)
    model_dir = hparams.out_dir
    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    eval_result = run_sample_decode_pungan(infer_model, infer_sess, model_dir,
                                           hparams, wsd_src_data)
    print('eval_result', eval_result)
    print('eval_result.len', len(eval_result))
    '''
    for i in range(0,256):
        if i%32 == 0:
            print('\n')
        print('eval_result[',i,']',eval_result[i])
    '''

    print('wsd_src_data', wsd_src_data)
    print('wsd_src_data.len', len(wsd_src_data))
    #len=16 wsd_src_data [u'problem%1:10:00::', u'problem%1:26:00::', u'drive%1:04:00::', u'drive%1:04:03::', u'identity%1:07:00::', u'identity%1:24:01::', u'point%1:10:01::', u'point%1:06:00::', u'tension%1:26:01::', u'tension%1:26:03::', u'log%2:32:00::', u'log%2:35:00::', u'fan%1:06:00::', u'fan%1:18:00::', u'file%2:32:00::', u'file%2:35:00::']

    eval_result_new = []
    for block in range(len(eval_result) // (2 * hparams.sample_size)):  # integer division so range() gets an int
        src_word1, src_word2 = wsd_src_data[2 *
                                            block], wsd_src_data[2 * block + 1]
        for sent_id in range(block * hparams.sample_size,
                             (block + 1) * hparams.sample_size):
            tgt_sent = src_word1.decode().encode(
                'utf-8') + ' ' + eval_result[sent_id]
            eval_result_new.append(tgt_sent)
    return wsd_src_data, eval_result_new
Code Example #8
def infer_fn(hparams, identity, scope=None, extra_args=None, target_session=""):
  """main entry point for inference and evaluation."""
  # create infer and eval models
  infer_model = model_helper.create_infer_model(
      diag_model.Model, hparams, scope, extra_args=extra_args)
  eval_model = model_helper.create_eval_model(diag_model.Model, hparams, scope)
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  # create the eval session
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)

  secondary_fn_tmp(hparams, identity, hparams.out_dir, infer_model, eval_model,
                   eval_sess, "infer", single_worker_inference)
Code Example #9
def inference(ckpt,
              inference_input_file,
              inference_output_file,
              hparams,
              num_workers=1):
    """Perform inference."""
    model_creator = model.Model
    infer_model = model_helper.create_infer_model(model_creator, hparams)

    single_worker_inference(
        infer_model,
        ckpt,
        inference_input_file,
        inference_output_file,
        hparams)
Code Example #10
    def _createTestInferCheckpoint(self, hparams, out_dir):
        if not hparams.attention:
            model_creator = nmt_model.Model
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
            model_creator = gnmt_model.GNMTModel
        else:
            raise ValueError("Unknown model architecture")

        infer_model = model_helper.create_infer_model(model_creator, hparams)
        with self.test_session(graph=infer_model.graph) as sess:
            loaded_model, global_step = model_helper.create_or_load_model(
                infer_model.model, out_dir, sess, "infer_name")
            ckpt = loaded_model.saver.save(sess,
                                           os.path.join(
                                               out_dir, "translate.ckpt"),
                                           global_step=global_step)
        return ckpt
Code Example #11
File: inference.py Project: ml-lab/Pun-GAN
def inference(ckpt,
              inference_input_file,
              inference_output_file,
              hparams,
              num_workers=1,
              jobid=0,
              scope=None):
    """Perform translation."""
    if hparams.inference_indices:
        assert num_workers == 1

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)
    #emb_matrix = model_helper._create_or_load_embed("embedding_encoder", hparams.src_vocab_file, hparams.src_embed_file,
    #hparams.src_vocab_size, hparams.batch_size, tf.float32)
    #emb_matrix =infer_model.model.embedding_encoder
    #print ("emb_matrix",emb_matrix)
    if num_workers == 1:
        single_worker_inference(
            #emb_matrix,
            infer_model,
            ckpt,
            inference_input_file,
            inference_output_file,
            hparams,
            model_creator)
    else:
        multi_worker_inference(infer_model,
                               ckpt,
                               inference_input_file,
                               inference_output_file,
                               hparams,
                               num_workers=num_workers,
                               jobid=jobid)
Code Example #12
    def _createTestInferCheckpoint(self, hparams, name):
        # Prepare
        hparams.vocab_prefix = ("nmt/testdata/test_infer_vocab")
        hparams.src_vocab_file = hparams.vocab_prefix + "." + hparams.src
        hparams.tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
        out_dir = os.path.join(tf.test.get_temp_dir(), name)
        os.makedirs(out_dir)
        hparams.out_dir = out_dir

        # Create check point
        model_creator = inference.get_model_creator(hparams)
        infer_model = model_helper.create_infer_model(model_creator, hparams)
        with self.test_session(graph=infer_model.graph) as sess:
            loaded_model, global_step = model_helper.create_or_load_model(
                infer_model.model, out_dir, sess, "infer_name")
            ckpt_path = loaded_model.saver.save(sess,
                                                os.path.join(
                                                    out_dir, "translate.ckpt"),
                                                global_step=global_step)
        return ckpt_path
Code Example #13
def decode(ckpt_path, inference_input_file, inference_output_file, hparams):

    model_creator = get_model_creator(hparams)
    infer_model = model_helper.create_infer_model(model_creator, hparams, None)
    sess, loaded_infer_model = start_sess_and_load_model(
        infer_model, ckpt_path)

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    print("infer_data:", infer_data)

    for i in range(len(infer_data)):

        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: [infer_data[i]],
                     infer_model.batch_size_placeholder: 1
                 })

        if hparams.task == "joint":
            outputs, intent_pred, src_seq_length, attention_summary = \
                loaded_infer_model.decode(sess)
        elif hparams.task == "intent":
            intent_pred, attention_summary = loaded_infer_model.decode(sess)

        if hparams.infer_mode == "beam_search":
            # get the top translation.
            outputs = outputs[0]

        utils.print_out("             src: %s" % infer_data[i])
        if hparams.task == "joint":
            translation = nmt_utils.get_translation(
                outputs,
                src_seq_length,
                sent_id=0,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
            utils.print_out(b"   intent (pred): %s" % intent_pred[0])
            utils.print_out(b"     slot (pred): %s\n" % translation)
        elif hparams.task == "intent":
            utils.print_out(b"   intent (pred): %s\n" % intent_pred[0])
Code Example #14
def modelInit():

    out_dir = 'tmp/nmt_attention_model'

    print("loading parameters")
    hparams = utils.load_hparams(out_dir)
    ckpt = tf.train.latest_checkpoint(out_dir)

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    infer_model = model_helper.create_infer_model(model_creator,
                                                  hparams,
                                                  scope=None)
    return infer_model, ckpt, hparams
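
Several examples (#5, #13, #19, #21) delegate the architecture branching seen above to a get_model_creator helper. Here is a sketch consistent with the inline if/elif chains of Code Examples #1 and #14; treating this as any particular project's implementation is an assumption (Code Example #21, for instance, passes its flags object instead of hparams):

def get_model_creator(hparams):
    """Assumed factoring of the architecture dispatch used across these examples."""
    if not hparams.attention:
        return nmt_model.Model
    elif hparams.attention_architecture == "standard":
        return attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        return gnmt_model.GNMTModel
    raise ValueError("Unknown model architecture %s" %
                     hparams.attention_architecture)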
Code Example #15
    def _load_model(self, checkpoint_path, default_hparams_path,
                    model_hparams_path, source_vocab_path, target_vocab_path):
        hparams = self._create_hparams(default_hparams_path,
                                       model_hparams_path, source_vocab_path,
                                       target_vocab_path)

        model_creator = gnmt_model.GNMTModel
        infer_model = model_helper.create_infer_model(model_creator,
                                                      hparams,
                                                      scope=None)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=GPU_MEM_FRAC)
        sess = tf.Session(graph=infer_model.graph,
                          config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))
        with infer_model.graph.as_default():
            nmt_model = model_helper.load_model(infer_model.model,
                                                checkpoint_path, sess, "infer")

        return sess, nmt_model, infer_model, hparams
Code Example #16
File: nmt_eval.py Project: zmxdream/parallax
def eval_fn(hparams, scope=None, target_session=""):
    """Evaluate a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    # First evaluation
    ckpt_size = len(
        tf.train.get_checkpoint_state(model_dir).all_model_checkpoint_paths)
    for ckpt_index in range(ckpt_size):
        train.run_full_eval(model_dir,
                            infer_model,
                            infer_sess,
                            eval_model,
                            eval_sess,
                            hparams,
                            None,
                            sample_src_data,
                            sample_tgt_data,
                            avg_ckpts,
                            ckpt_index=ckpt_index)
Code Example #17
File: train.py Project: ml-lab/Pun-GAN
def train(hparams, scope=None, target_session="", compute_ppl=0):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:  # choose this model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)
    wsd_src_file = "%s" % (hparams.sample_prefix)

    wsd_src_data = inference.load_data(wsd_src_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    '''
  run_full_eval(
      model_dir, infer_model, infer_sess,
      eval_model, eval_sess, hparams,
      summary_writer, sample_src_data,
      sample_tgt_data, avg_ckpts)
  '''
    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    end_step = global_step + 100
    while global_step < end_step:  # num_train_steps
        ### Run a step ###
        start_time = time.time()
        try:
            # then forward inference result to WSD, get reward
            step_result = loaded_train_model.train(train_sess)
            # forward reward to placeholder of loaded_train_model, and write a new train function where loss = loss*reward
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            # run_sample_decode(infer_model, infer_sess, model_dir, hparams,
            #                   summary_writer, sample_src_data, sample_tgt_data)

            # only for pretrain
            # run_external_eval(infer_model, infer_sess, model_dir, hparams,
            #                   summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})

            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result, hparams)
        summary_writer.add_summary(step_summary, global_step)
        if compute_ppl:
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)
        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)
Code Example #18
File: train.py Project: ml-lab/Pun-GAN
def run_sample_decode_pungan_prepare(hparams, scope=None, target_session=""):
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    if not hparams.attention:  # choose this model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    def dealt(input, output):
        with open(input) as f:
            with open(output, 'w') as fw:
                for line in f:
                    l = line.strip().split()
                    l.reverse()
                    sent = ' '.join(l)
                    fw.write(sent + '\n')

    wsd_src_file = "%s" % (hparams.sample_prefix)
    wsd_src_file_new = wsd_src_file + '.new'
    dealt(wsd_src_file, wsd_src_file_new)

    wsd_src_data = inference.load_data(wsd_src_file_new)
    model_dir = hparams.out_dir
    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)
    print('len wsd_src_data', len(wsd_src_data))
    eval_result = []
    for i in range(len(wsd_src_data) // 32):  # integer division for Python 3 compatibility
        eval_result += run_sample_decode_pungan(
            infer_model, infer_sess, model_dir, hparams,
            wsd_src_data[i * 32:(i + 1) * 32])
    print('eval_result')
    print(eval_result)
    print(len(eval_result))
    backward_step1_in = []
    with open(PUNGAN_ROOT_PATH +
              '/Pun_Generation/data/1backward/backward_step1.in') as f:
        for line in f:
            backward_step1_in.append(line.strip())

    def wsd_input_format(wsd_src_data, eval_result):
        '''
        test_data[0] {'target_word': u'art#n', 'target_sense': None, 'id': 'senseval2.d000.s000.t000', 'context': ['the', '<target>', 'of', 'change_ringing', 'be', 'peculiar', 'to', 'the', 'english', ',', 'and', ',', 'like', 'most', 'english', 'peculiarity', ',', 'unintelligible', 'to', 'the', 'rest', 'of', 'the', 'world', '.'], 'poss': ['DET', 'NOUN', 'ADP', 'NOUN', 'VERB', 'ADJ', 'PRT', 'DET', 'NOUN', '.', 'CONJ', '.', 'ADP', 'ADJ', 'ADJ', 'NOUN', '.', 'ADJ', 'PRT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', '.']}
        '''
        wsd_input = []
        senses_input = []

        for i in range(len(eval_result)):
            block = i // 32  # integer division: index of the 32-sentence block
            src_word1, src_word2 = backward_step1_in[
                2 * block], backward_step1_in[2 * block + 1]
            tgt_sent = wsd_src_data[i].decode().encode(
                'utf-8') + ' ' + eval_result[i]
            tgt_word = src_word1

            synset = wn.lemma_from_key(tgt_word).synset()
            s = synset.name()
            target_word = '#'.join(s.split('.')[:2])
            context = tgt_sent.split(' ')

            for j in range(len(context)):
                if context[j] == tgt_word:
                    context[j] = '<target>'
            poss_list = ['.' for _ in range(len(context))]
            tmp_dict = {
                'target_word': target_word,
                'target_sense': None,
                'id': None,
                'context': context,
                'poss': poss_list
            }
            wsd_input.append(tmp_dict)
            senses_input.append((src_word1, src_word2))
        return wsd_input, senses_input

    wsd_input, senses_input = wsd_input_format(wsd_src_data, eval_result)
    print('wsd_input', wsd_input)
    print("len of wsd_input", len(wsd_input))
    return wsd_input, senses_input, wsd_src_data, eval_result
Code Example #19
File: train.py Project: Negahead/chichat
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    # Create model
    model_creator = get_model_creator(hparams)
    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data,
                  avg_ckpts)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info, get_best_results(hparams),
                            log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            add_info_summaries(summary_writer, global_step, info)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = (run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts))
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step,
                            info, result_summary, log_f)
            summary_writer.close()

    return final_eval_metrics, global_step
Code Example #20
def train(hparams, scope=None, target_session=""):
  """Train a translation model."""
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  steps_per_eval = 10 * steps_per_stats
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  if not hparams.attention:
    model_creator = nmt_model.Model
  elif hparams.attention_architecture == "standard":
    model_creator = attention_model.AttentionModel
  elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
    model_creator = gnmt_model.GNMTModel
  else:
    raise ValueError("Unknown model architecture")

  train_model = model_helper.create_train_model(model_creator, hparams, scope)
  eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
  infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

  # Preload data for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  avg_step_time = 0.0

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement)

  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)
  infer_sess = tf.Session(
      target=target_session, config=config_proto, graph=infer_model.graph)

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  # First evaluation
  run_full_eval(
      model_dir, infer_model, infer_sess,
      eval_model, eval_sess, hparams,
      summary_writer, sample_src_data,
      sample_tgt_data)

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()),
      log_f)

  # Initialize all of the iterators
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      (_, step_loss, step_predict_count, step_summary, global_step,
       step_word_count, batch_size) = step_result
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset.  Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      run_sample_decode(infer_model, infer_sess,
                        model_dir, hparams, summary_writer, sample_src_data,
                        sample_tgt_data)
      dev_scores, test_scores, _ = run_external_eval(
          infer_model, infer_sess, model_dir,
          hparams, summary_writer)
      train_sess.run(
          train_model.iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)

    # update statistics
    step_time += (time.time() - start_time)

    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step

      # Print statistics for the previous epoch.
      avg_step_time = step_time / steps_per_stats
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      utils.print_out(
          "  global step %d lr %g "
          "step-time %.2fs wps %.2fK ppl %.2f %s" %
          (global_step,
           loaded_train_model.learning_rate.eval(session=train_sess),
           avg_step_time, speed, train_ppl, _get_best_results(hparams)),
          log_f)
      if math.isnan(train_ppl):
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

    if global_step - last_eval_step >= steps_per_eval:
      last_eval_step = global_step

      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

      # Save checkpoint
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)

      # Evaluate on dev/test
      run_sample_decode(infer_model, infer_sess,
                        model_dir, hparams, summary_writer, sample_src_data,
                        sample_tgt_data)
      dev_ppl, test_ppl = run_internal_eval(
          eval_model, eval_sess, model_dir, hparams, summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step

      # Save checkpoint
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "translate.ckpt"),
          global_step=global_step)
      run_sample_decode(infer_model, infer_sess,
                        model_dir, hparams, summary_writer, sample_src_data,
                        sample_tgt_data)
      dev_scores, test_scores, _ = run_external_eval(
          infer_model, infer_sess, model_dir,
          hparams, summary_writer)

  # Done training
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "translate.ckpt"),
      global_step=global_step)

  result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
      model_dir, infer_model, infer_sess,
      eval_model, eval_sess, hparams,
      summary_writer, sample_src_data,
      sample_tgt_data)
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()),
      log_f)
  utils.print_time("# Done training!", start_train_time)

  summary_writer.close()

  utils.print_out("# Start evaluating saved best models.")
  for metric in hparams.metrics:
    best_model_dir = getattr(hparams, "best_" + metric + "_dir")
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _, _, _, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out("# Best %s, step %d "
                    "step-time %.2f wps %.2fK, %s, %s" %
                    (metric, best_global_step, avg_step_time, speed,
                     result_summary, time.ctime()), log_f)
    summary_writer.close()

  return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
Code Example #21
File: train.py Project: kkw0877/A3C-tensorflow
def train(flags):
    """Train the policy gradient model. """
    
    out_dir = flags.out_dir
    num_train_steps = flags.num_train_steps
    steps_per_infer = flags.steps_per_infer
    
    # Create model for train, infer mode
    model_creator = get_model_creator(flags)
    train_model = model_helper.create_train_model(flags, model_creator)
    infer_model = model_helper.create_infer_model(flags, model_creator)

    # TODO. set for distributed training and multi gpu 
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    
    # Session for train, infer
    train_sess = tf.Session(
        config=config_proto, graph=train_model.graph)
    infer_sess = tf.Session(
        config=config_proto, graph=infer_model.graph)
    
    # Load the train model if there's the file in the directory
    # otherwise, initialize vars in the train model
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            out_dir, train_model.model, train_sess) 
    
    # Summary
    train_summary = "train_log"
    infer_summary = "infer_log"

    # Summary writer for train, infer
    train_summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, train_summary), train_model.graph)
    infer_summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, infer_summary))
    
    # First evaluation
    run_infer(infer_model, out_dir, infer_sess)

    # Initialize step var
    last_infer_steps = global_step
    
    # Training loop
    while global_step < num_train_steps:
        output_tuple = loaded_train_model.train(train_sess)
        global_step = output_tuple.global_step
        train_summary = output_tuple.train_summary    
    
        # Update train summary
        train_summary_writer.add_summary(train_summary, global_step)
        print('current global_step: {}'.format(global_step))

        # Evaluate the model for steps_per_infer 
        if global_step - last_infer_steps >= steps_per_infer:
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, 
                os.path.join(out_dir, "rl.ckpt"), global_step)
            
            last_infer_steps = global_step
            output_tuple = run_infer(infer_model, out_dir, infer_sess)  
            infer_summary = output_tuple.infer_summary

            # Update infer summary
            infer_summary_writer.add_summary(infer_summary, global_step)
    
    # Done training
    loaded_train_model.saver.save(train_sess, 
        os.path.join(out_dir, "rl.ckpt"), global_step)
    print('Train done')
Code Example #22
File: train.py Project: james-tn/Text_Summarization
def train(hps, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hps.log_device_placement
    out_dir = hps.out_dir
    num_train_steps = hps.num_train_steps
    steps_per_stats = hps.steps_per_stats
    steps_per_external_eval = hps.steps_per_external_eval
    steps_per_eval = 100 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hps.attention_architecture == "baseline":
        model_creator = AttentionModel
    else:
        model_creator = AttentionHistoryModel

    train_model = model_helper.create_train_model(model_creator, hps, scope)
    eval_model = model_helper.create_eval_model(model_creator, hps, scope)
    infer_model = model_helper.create_infer_model(model_creator, hps, scope)

    # Preload data for sample decoding.

    article_filenames = []
    abstract_filenames = []
    art_dir = hps.data_dir + '/article'
    abs_dir = hps.data_dir + '/abstract'
    for file in os.listdir(art_dir):
        if file.startswith(hps.dev_prefix):
            article_filenames.append(art_dir + "/" + file)
    for file in os.listdir(abs_dir):
        if file.startswith(hps.dev_prefix):
            abstract_filenames.append(abs_dir + "/" + file)
    # if random_decode:
    #     """if this is a random sampling process during training"""
    decode_id = random.randint(0, len(article_filenames) - 1)
    single_article_file = article_filenames[decode_id]
    single_abstract_file = abstract_filenames[decode_id]

    dev_src_file = single_article_file
    dev_tgt_file = single_abstract_file
    sample_src_data = inference_base_model.load_data(dev_src_file)
    sample_tgt_data = inference_base_model.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hps.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    # run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess, hps,
    #     summary_writer,sample_src_data,sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hps.batch_size * hps.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})
    epoch_step = 0
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = update_stats(stats, summary_writer, start_time,
                                   step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = check_stats(stats, global_step, steps_per_stats, hps,
                                      log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hps,
                                                  summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "summarized.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hps,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hps.metrics:
        best_model_dir = getattr(hps, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hps, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)
        summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
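The example above calls init_stats and update_stats without showing them. A minimal sketch consistent with how they are used here (update_stats returns the new global step) and with the per-step bookkeeping that example #24 performs inline; the dictionary keys are assumptions:

import time

def init_stats():
    """Fresh accumulator for one statistics window (sketch)."""
    return {"step_time": 0.0, "loss": 0.0,
            "predict_count": 0.0, "word_count": 0.0}

def update_stats(stats, summary_writer, start_time, step_result):
    """Accumulate one training step and write its summary (sketch).

    Assumes step_result unpacks as in example #24's non-curriculum branch:
    (_, loss, predict_count, summary, global_step, word_count, batch_size).
    """
    (_, step_loss, step_predict_count, step_summary, global_step,
     step_word_count, batch_size) = step_result
    summary_writer.add_summary(step_summary, global_step)
    stats["step_time"] += time.time() - start_time
    stats["loss"] += step_loss * batch_size
    stats["predict_count"] += step_predict_count
    stats["word_count"] += step_word_count
    return global_step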
Code Example #23
def train(hparams, scope=None):
    """Train a GNMT translation model, checkpointing and evaluating periodically."""
    model_dir = hparams.out_dir
    avg_ckpts = hparams.avg_ckpts
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval
    summary_name = "summary"

    model_creator = gnmt_model.GNMTModel
    train_model = model_helper.create_train_model(model_creator, hparams)
    eval_model = model_helper.create_eval_model(model_creator, hparams)
    infer_model = model_helper.create_infer_model(model_creator, hparams)

    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    train_sess = tf.Session(graph=train_model.graph, config=config_proto)
    eval_sess = tf.Session(graph=eval_model.graph, config=config_proto)
    infer_sess = tf.Session(graph=infer_model.graph, config=config_proto)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(model_dir, summary_name), train_model.graph)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = utils.load_data(dev_src_file)
    sample_tgt_data = utils.load_data(dev_tgt_file)

    # First evaluation
    result_summary, _, _ = run_full_eval(model_dir, infer_model, infer_sess,
                                         eval_model, eval_sess, hparams,
                                         summary_writer, sample_src_data,
                                         sample_tgt_data, avg_ckpts)
    utils.log('First evaluation: {}'.format(result_summary))

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    info = {
        "train_ppl": 0.0,
        "speed": 0.0,
        "avg_step_time": 0.0,
        "avg_grad_norm": 0.0,
        "learning_rate":
        loaded_train_model.learning_rate.eval(session=train_sess)
    }
    utils.log("Start step %d, lr %g" % (global_step, info["learning_rate"]))

    # Initialize all of the iterators
    train_sess.run(train_model.iterator.initializer)

    epoch = 1

    while True:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            utils.log(
                "Finished epoch %d, step %d. Perform external evaluation" %
                (epoch, global_step))
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)
            train_sess.run(train_model.iterator.initializer)

            if epoch < hparams.epochs:
                epoch += 1
                continue
            else:
                break

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats)
            print_step_info("  ", global_step, info,
                            "BLEU %.2f" % (hparams.best_bleu, ))
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.log("Save eval, global step %d" % (global_step, ))
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(model_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(model_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(model_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)
    print_step_info("Final, ", global_step, info, result_summary)
    utils.log("Done training!")

    summary_writer.close()

    utils.log("Start evaluating saved best models.")
    best_model_dir = hparams.best_bleu_dir
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
        hparams, summary_writer, sample_src_data, sample_tgt_data)
    print_step_info("Best BLEU, ", best_global_step, info, result_summary)
    summary_writer.close()

    if avg_ckpts:
        best_model_dir = hparams.avg_best_bleu_dir
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("Averaged Best BLEU, ", best_global_step, info,
                        result_summary)
        summary_writer.close()

    return final_eval_metrics, global_step
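Example #23's update_stats variant returns (global_step, learning_rate, step_summary) instead, and hands the accumulated window to process_stats and print_step_info. A sketch of those two helpers, consistent with the info keys the example reads (the avg_grad_norm bookkeeping is omitted for brevity, and _safe_exp mirrors the overflow-guarded exponential in the upstream NMT utils):

import math

def _safe_exp(value):
    """exp() that saturates to inf instead of raising OverflowError."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")

def process_stats(stats, info, global_step, steps_per_stats):
    """Fold a statistics window into `info`; flag perplexity overflow (sketch)."""
    info["avg_step_time"] = stats["step_time"] / steps_per_stats
    info["speed"] = stats["word_count"] / (1000.0 * stats["step_time"])
    info["train_ppl"] = _safe_exp(stats["loss"] / stats["predict_count"])
    ppl = info["train_ppl"]
    return math.isnan(ppl) or math.isinf(ppl) or ppl > 1e20

def print_step_info(prefix, global_step, info, result_summary):
    """One-line progress report in the style of example #24's log lines (sketch)."""
    print("%sstep %d lr %g step-time %.2fs wps %.2fK ppl %.2f %s" %
          (prefix, global_step, info["learning_rate"], info["avg_step_time"],
           info["speed"], info["train_ppl"], result_summary))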
Code Example #24
File: train.py Project: herbertchen1/MANNs4NMT
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="w")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)

    if hparams.curriculum == 'none':
        train_sess.run(
            train_model.iterator.initializer,
            feed_dict={train_model.skip_count_placeholder: skip_count})
    else:
        if hparams.curriculum == 'predictive_gain':
            exp3s = Exp3S(hparams.num_curriculum_buckets, 0.001, 0, 0.05)
        elif hparams.curriculum == 'look_back_and_forward':
            curriculum_point = 0

        # Note: in this project, train_model.iterator.initializer is a list
        # holding one dataset iterator per curriculum bucket (despite the
        # name); each entry exposes .initializer and .string_handle().
        handle = train_model.iterator.handle
        for i in range(hparams.num_curriculum_buckets):
            train_sess.run(
                train_model.iterator.initializer[i].initializer,
                feed_dict={train_model.skip_count_placeholder: skip_count})

        iterator_handles = [
            train_sess.run(
                train_model.iterator.initializer[i].string_handle(),
                feed_dict={train_model.skip_count_placeholder: skip_count})
            for i in range(hparams.num_curriculum_buckets)
        ]

    utils.print_out("Starting training")

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            if hparams.curriculum != 'none':
                if hparams.curriculum == 'predictive_gain':
                    lesson = exp3s.draw_task()
                elif hparams.curriculum == 'look_back_and_forward':
                    if curriculum_point == hparams.num_curriculum_buckets:
                        lesson = np.random.randint(
                            low=0, high=hparams.num_curriculum_buckets)
                    else:
                        lesson = curriculum_point if np.random.random_sample(
                        ) < 0.8 else np.random.randint(
                            low=0, high=hparams.num_curriculum_buckets)

                step_result = loaded_train_model.train(
                    hparams,
                    train_sess,
                    handle=handle,
                    iterator_handle=iterator_handles[lesson],
                    use_fed_source_placeholder=loaded_train_model.
                    use_fed_source,
                    fed_source_placeholder=loaded_train_model.fed_source)

                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size, source) = step_result

                if hparams.curriculum == 'predictive_gain':
                    # Re-run the loss on the same batch to measure the
                    # prediction gain.  Fetch the tensor directly (not
                    # wrapped in a list) so new_loss is a scalar below.
                    new_loss = train_sess.run(
                        loaded_train_model.train_loss,
                        feed_dict={
                            handle: iterator_handles[lesson],
                            loaded_train_model.use_fed_source: True,
                            loaded_train_model.fed_source: source
                        })

                    curriculum_point_a = lesson * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (
                        lesson + 1) * (hparams.src_max_len //
                                       hparams.num_curriculum_buckets) + 1

                    v = step_loss - new_loss
                    exp3s.update_w(
                        v,
                        float(curriculum_point_a + curriculum_point_b) / 2.0)
                elif hparams.curriculum == 'look_back_and_forward':
                    utils.print_out("step loss: %s, lesson: %s" %
                                    (step_loss, lesson))
                    curriculum_point_a = curriculum_point * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (curriculum_point + 1) * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1

                    if step_loss < (hparams.curriculum_progress_loss *
                                    (float(curriculum_point_a +
                                           curriculum_point_b) / 2.0)):
                        curriculum_point += 1
            else:
                step_result = loaded_train_model.train(hparams, train_sess)
                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            # utils.print_out(
            #     "# Finished an epoch, step %d. Perform external evaluation" %
            #     global_step)
            # run_sample_decode(infer_model, infer_sess,
            #                   model_dir, hparams, summary_writer, sample_src_data,
            #                   sample_tgt_data)
            # dev_scores, test_scores, _ = run_external_eval(
            #     infer_model, infer_sess, model_dir,
            #     hparams, summary_writer)
            if hparams.curriculum == 'none':
                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            else:
                train_sess.run(
                    train_model.iterator.initializer[lesson].initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            if hparams.curriculum == 'predictive_gain':
                utils.print_out("lesson: %s, step loss: %s, new_loss: %s" %
                                (lesson, step_loss, new_loss))
                utils.print_out("exp3s dist: %s" % (exp3s.pi, ))

            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss /
                                       checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)

            if math.isnan(train_ppl):
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)

            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

        # if global_step - last_external_eval_step >= steps_per_external_eval:
        #   last_external_eval_step = global_step

        #   # Save checkpoint
        #   loaded_train_model.saver.save(
        #       train_sess,
        #       os.path.join(out_dir, "translate.ckpt"),
        #       global_step=global_step)
        #   run_sample_decode(infer_model, infer_sess,
        #                     model_dir, hparams, summary_writer, sample_src_data,
        #                     sample_tgt_data)
        #   dev_scores, test_scores, _ = run_external_eval(
        #       infer_model, infer_sess, model_dir,
        #       hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)

    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)

    summary_writer.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
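Example #24's predictive-gain curriculum relies on an Exp3S class exposing draw_task(), update_w(), and a pi attribute. Below is a simplified sketch in the spirit of Exp3.S from Graves et al. (2017), "Automated Curriculum Learning for Neural Networks". The reading of the constructor arguments as (num_tasks, eta, beta, eps) is an assumption matching the call Exp3S(hparams.num_curriculum_buckets, 0.001, 0, 0.05), and the full Exp3.S additive-mixing step for nonstationary arms is omitted.

import numpy as np

class Exp3S:
    """Bandit over curriculum buckets: one arm per source-length bucket."""

    def __init__(self, num_tasks, eta, beta, eps):
        self.num_tasks = num_tasks  # number of curriculum buckets
        self.eta = eta              # learning rate for the weight update
        self.beta = beta            # implicit-exploration reward bonus
        self.eps = eps              # uniform exploration mixed into pi
        self.log_w = np.zeros(num_tasks)
        self.last_task = 0

    @property
    def pi(self):
        """Sampling distribution: softmax of weights mixed with uniform."""
        w = np.exp(self.log_w - self.log_w.max())
        return (1.0 - self.eps) * (w / w.sum()) + self.eps / self.num_tasks

    def draw_task(self):
        """Sample the next lesson (curriculum bucket) to train on."""
        self.last_task = int(np.random.choice(self.num_tasks, p=self.pi))
        return self.last_task

    def update_w(self, reward, duration):
        """Update arm weights from the last lesson's prediction gain.

        In example #24, reward is step_loss - new_loss and duration is the
        bucket's mean source length, so reward/duration is learning
        progress per token; it is clipped to [-1, 1] before the
        importance-weighted Exp3-style update.
        """
        r = float(np.clip(reward / max(duration, 1e-8), -1.0, 1.0))
        pi = self.pi
        r_hat = np.full(self.num_tasks, self.beta, dtype=np.float64) / pi
        r_hat[self.last_task] += r / pi[self.last_task]
        self.log_w += self.eta * r_hat

The effect is that buckets where the model's loss is currently dropping fastest per token get sampled more often, while the eps floor keeps every bucket visited occasionally.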