Example #1
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams, model_creator):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding single_worker_inference")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_translations_per_input=hparams.num_translations_per_input)
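
Every example here begins with `load_data`, which is not shown on this page. A minimal sketch of a plausible implementation, assuming one source sentence per line (an assumption, not necessarily the repository's actual helper):

def load_data(inference_input_file, hparams=None):
    # Hypothetical minimal loader: one source sentence per line.
    # The codecs/tf.gfile usage mirrors the file handling in the
    # multi-worker examples below.
    with codecs.getreader("utf-8")(
            tf.gfile.GFile(inference_input_file, mode="rb")) as f:
        infer_data = f.read().splitlines()
    return infer_data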
Example #2
def _external_eval(model,
                   global_step,
                   sess,
                   hparams,
                   iterator,
                   iterator_feed_dict,
                   tgt_file,
                   label,
                   summary_writer,
                   save_on_best,
                   avg_ckpts=False):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = hparams.out_dir
    decode = global_step > 0

    if avg_ckpts:
        label = "avg_" + label

    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)
    scores = nmt_utils.decode_and_evaluate(
        label,
        model,
        sess,
        output,
        ref_file=tgt_file,
        metrics=hparams.metrics,
        subword_option=hparams.subword_option,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        decode=decode,
        infer_mode=hparams.infer_mode)
    # Save on best metrics
    if decode:
        for metric in hparams.metrics:
            if avg_ckpts:
                best_metric_label = "avg_best_" + metric
            else:
                best_metric_label = "best_" + metric

            utils.add_summary(summary_writer, global_step,
                              "%s_%s" % (label, metric), scores[metric])
            # metric: larger is better
            if save_on_best and scores[metric] > getattr(
                    hparams, best_metric_label):
                setattr(hparams, best_metric_label, scores[metric])
                model.saver.save(sess,
                                 os.path.join(
                                     getattr(hparams,
                                             best_metric_label + "_dir"),
                                     "translate.ckpt"),
                                 global_step=model.global_step)
        utils.save_hparams(out_dir, hparams)
    return scores
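
`_external_eval` reads and updates `best_<metric>` (or `avg_best_<metric>`) attributes plus matching `*_dir` paths on `hparams`, so those must exist before the first call. A hedged sketch of that setup; the directory layout is an assumption:

# Hypothetical seeding of the attributes _external_eval compares against;
# the real repository performs this while constructing hparams.
for metric in hparams.metrics:
    best_label = "best_" + metric
    setattr(hparams, best_label, 0.0)  # metrics where larger is better
    best_dir = os.path.join(hparams.out_dir, best_label)  # assumed layout
    setattr(hparams, best_label + "_dir", best_dir)
    tf.gfile.MakeDirs(best_dir)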
Example #3
def single_worker_inference(sess,
                            infer_model,
                            loaded_infer_model,
                            inference_input_file,
                            inference_output_file,
                            hparams):
  """Inference with a single worker."""
  output_infer = inference_output_file

  # Read data
  infer_data = load_data(inference_input_file, hparams)

  with infer_model.graph.as_default():
    sess.run(
        infer_model.iterator.initializer,
        feed_dict={
            infer_model.src_placeholder: infer_data,
            infer_model.batch_size_placeholder: hparams.infer_batch_size
        })
    # Decode
    utils.print_out("# Start decoding")
    if hparams.inference_indices:
      _decode_inference_indices(
          loaded_infer_model,
          sess,
          output_infer=output_infer,
          output_infer_summary_prefix=output_infer,
          inference_indices=hparams.inference_indices,
          tgt_eos=hparams.eos,
          subword_option=hparams.subword_option)
    else:
      nmt_utils.decode_and_evaluate(
          "infer",
          loaded_infer_model,
          sess,
          output_infer,
          ref_file=None,
          metrics=hparams.metrics,
          subword_option=hparams.subword_option,
          beam_width=hparams.beam_width,
          tgt_eos=hparams.eos,
          num_translations_per_input=hparams.num_translations_per_input,
          infer_mode=hparams.infer_mode)
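
Unlike Examples #1 and #4, this variant expects the caller to own the session and the already-restored model. A sketch of such a call site, assuming the `model_helper` factory names used elsewhere on this page; treat the exact signatures as an assumption:

# Assumed call site for the variant above.
infer_model = model_helper.create_infer_model(model_creator, hparams)
sess = tf.Session(graph=infer_model.graph, config=utils.get_config_proto())
with infer_model.graph.as_default():
    loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                 sess, "infer")
single_worker_inference(sess, infer_model, loaded_infer_model,
                        inference_input_file, inference_output_file, hparams)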
Example #4
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                bpe_delimiter=hparams.bpe_delimiter)
        else:
            nmt_utils.decode_and_evaluate("infer",
                                          loaded_infer_model,
                                          sess,
                                          output_infer,
                                          ref_file=None,
                                          metrics=hparams.metrics,
                                          bpe_delimiter=hparams.bpe_delimiter,
                                          beam_width=hparams.beam_width,
                                          tgt_eos=hparams.eos)
Example #5
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
    output_infer = inference_output_file

    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_util.load_model(infer_model.model, ckpt,
                                                   sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_translations_per_input=hparams.num_translations_per_input)
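
Both branches of these functions hinge on `hparams.inference_indices`: when set, only the listed sentence positions are decoded. A hedged sketch of how such indices are typically parsed from a comma-separated flag (the `inference_list` flag name is an assumption):

# Hypothetical parsing of an inference_list flag such as "0,3,7" into the
# indices consumed by _decode_inference_indices; None decodes everything.
inference_indices = None
if getattr(hparams, "inference_list", None):
    inference_indices = [
        int(token) for token in hparams.inference_list.split(",")]
hparams.inference_indices = inference_indices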
Example #6
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_output_file, hparams, num_workers, jobid):
    """Inference using multiple workers."""
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, jobid)
    output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    # Split data to multiple workers
    total_load = len(infer_data)
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = jobid * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:

        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 {infer_model.src_placeholder: infer_data})

        # Decode
        utils.print_out("# Start decoding")
        nmt_utils.decode_and_evaluate("infer",
                                      loaded_infer_model,
                                      sess,
                                      output_infer,
                                      ref_file=None,
                                      metrics=hparams.metrics,
                                      bpe_delimiter=hparams.bpe_delimiter,
                                      beam_width=hparams.beam_width,
                                      tgt_eos=hparams.eos)

        # Change file name to indicate the file writing is completed.
        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        # Job 0 is responsible for the clean up.
        if jobid != 0: return

        # Now write all translations
        with codecs.getwriter("utf-8")(tf.gfile.GFile(final_output_infer,
                                                      mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waitting job %d to complete." %
                                    worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(tf.gfile.GFile(
                        worker_infer_done, mode="rb")) as f:
                    for translation in f:
                        final_f.write("%s" % translation)
                tf.gfile.Remove(worker_infer_done)
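
The rename to a `_done_<jobid>` file doubles as an atomic completion marker: each worker publishes its shard only once it is fully written, and job 0 polls for every marker before concatenating. The same wait loop in isolation, as a hypothetical standalone helper:

# Hypothetical standalone version of the polling loop used above.
def wait_for_markers(output_prefix, num_workers, poll_secs=10):
    for worker_id in range(num_workers):
        marker = "%s_done_%d" % (output_prefix, worker_id)
        while not tf.gfile.Exists(marker):
            time.sleep(poll_secs)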
Example #7
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_output_file, hparams, num_workers,
                           job_id):
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, job_id)
    output_infer_done = "%s_done_%d" % (inference_output_file, job_id)

    infer_data = load_data(inference_input_file, hparams)

    total_load = len(infer_data)
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = job_id * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_util.load_model(infer_model.model, ckpt,
                                                   sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        utils.print_out("# Start decoding")
        nmt_utils.decode_and_evaluate(
            "infer",
            loaded_infer_model,
            sess,
            output_infer,
            ref_file=None,
            metrics=hparams.metrics,
            subword_option=hparams.subword_option,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)

        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        if job_id != 0: return

        with codecs.getwriter("utf-8")(tf.gfile.GFile(final_output_infer,
                                                      mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waitting job %d to complete." %
                                    worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(tf.gfile.GFile(
                        worker_infer_done, mode="rb")) as f:
                    for translation in f:
                        final_f.write("%s" % translation)

            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                tf.gfile.Remove(worker_infer_done)
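
The shard arithmetic `load_per_worker = int((total_load - 1) / num_workers) + 1` is a ceiling division, so every sentence is covered and only the last worker can receive a short shard. A worked example with 10 sentences across 3 workers:

# Worked example of the ceiling-division sharding used above.
total_load, num_workers = 10, 3
load_per_worker = (total_load - 1) // num_workers + 1  # ceil(10 / 3) = 4
for job_id in range(num_workers):
    start = job_id * load_per_worker
    end = min(start + load_per_worker, total_load)
    print(job_id, start, end)  # (0, 0, 4), (1, 4, 8), (2, 8, 10)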
Example #8
def single_worker_inference(sess, infer_model, loaded_infer_model,
                            inference_input_file, inference_output_file,
                            hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    infer_data_feed = infer_data

    # Sort the input by length when hparams.inference_indices is not set.
    index_pair = {}
    new_input = []
    start_time = time.time()  # defined here so the timing below always works
    if hparams.inference_indices is None:
        input_length = [(len(line.split()), i)
                        for i, line in enumerate(infer_data)]
        sorted_input_bylens = sorted(input_length)
        for ni, (_, oi) in enumerate(sorted_input_bylens):
            new_input.append(infer_data[oi])
            index_pair[oi] = ni
        infer_data_feed = new_input

    with infer_model.graph.as_default():
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data_feed,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            _, end_time, num_sentences = nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_translations_per_input=hparams.num_translations_per_input,
                infer_mode=hparams.infer_mode,
                index_pair=index_pair)
            duration = end_time - start_time
            if hparams.infer_batch_size == 1:
                print("  The latency of the model is %.4f ms/sentence" %
                      (1000 * duration / num_sentences))
            else:
                print("  The throughput of the model is %.4f sentences/s" %
                      (num_sentences / duration))
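
Sorting inputs by length reduces padding within each batch; `index_pair` records where each original line ended up so the output can be written back in input order (assuming `decode_and_evaluate` honors the `index_pair` argument this fork passes). A self-contained illustration of the bookkeeping:

# Illustration of the index_pair mapping built above.
infer_data = ["a b c", "x", "p q"]
order = sorted((len(line.split()), i) for i, line in enumerate(infer_data))
index_pair = {oi: ni for ni, (_, oi) in enumerate(order)}
sorted_data = [infer_data[oi] for _, oi in order]
# sorted_data == ["x", "p q", "a b c"]; index_pair == {0: 2, 1: 0, 2: 1}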
Example #9
def _external_eval(model, global_step, sess, hparams, iterator,
                   iterator_feed_dict, tgt_file, label, summary_writer,
                   save_on_best):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = hparams.out_dir
    decode = global_step > 0
    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)

    scores = nmt_utils.decode_and_evaluate(label,
                                           model,
                                           sess,
                                           output,
                                           ref_file=tgt_file,
                                           metrics=hparams.metrics,
                                           bpe_delimiter=hparams.bpe_delimiter,
                                           beam_width=hparams.beam_width,
                                           tgt_eos=hparams.eos,
                                           decode=decode)
    # Save on best metrics
    if decode:
        for metric in hparams.metrics:
            utils.add_summary(summary_writer, global_step,
                              "%s_%s" % (label, metric), scores[metric])
            # metric: larger is better
            # NOTE: 'w+' truncates the file, so only the latest score is kept.
            with open("./tmp/nmt_model/score", 'w+') as resu:
                resu.write(str(global_step) + ":" + str(scores[metric]) + "\n")
            if save_on_best and scores[metric] > getattr(hparams,
                                                         "top_score")[0]:
                new_top_score = []
                new_top_score_name = []
                isTopScore = True
                for score, name in zip(getattr(hparams, "top_score"),
                                       getattr(hparams, "top_score_name")):
                    if scores[metric] < score and isTopScore:
                        new_top_score.append(scores[metric])
                        new_top_score_name.append(str(global_step))
                        isTopScore = False
                    new_top_score.append(score)
                    new_top_score_name.append(name)
                if isTopScore:
                    new_top_score.append(scores[metric])
                    new_top_score_name.append(str(global_step))
                setattr(hparams, "top_score", new_top_score[1:])
                setattr(hparams, "top_score_name", new_top_score_name[1:])
                setattr(hparams, "best_" + metric,
                        new_top_score[len(new_top_score) - 1])
                model.saver.save(sess,
                                 os.path.join(
                                     getattr(hparams,
                                             "best_" + metric + "_dir"),
                                     "translate.ckpt"),
                                 global_step=model.global_step)
                if new_top_score[0] != 0:
                    os.system('rm ' +
                              getattr(hparams, "best_" + metric + "_dir") +
                              '/translate.ckpt-' + new_top_score_name[0] + '*')
        utils.save_hparams(out_dir, hparams)
    return scores
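
This variant keeps the N best checkpoints via the parallel `top_score` / `top_score_name` lists, held in ascending order so the worst survivor sits at index 0 and is evicted first. A hedged sketch of how those hparams might be seeded, assuming a keep-top-5 policy:

# Hypothetical seeding of the rolling top-score lists used above; zero
# placeholders mean no stale checkpoint is deleted on the first updates.
num_keep = 5
setattr(hparams, "top_score", [0.0] * num_keep)
setattr(hparams, "top_score_name", ["0"] * num_keep)
for metric in hparams.metrics:
    setattr(hparams, "best_" + metric, 0.0)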