Example No. 1
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_sos=hparams.sos,
                tgt_eos=hparams.eos,
                bpe_delimiter=hparams.bpe_delimiter)
        else:
            nmt_utils.decode_and_evaluate("infer",
                                          loaded_infer_model,
                                          sess,
                                          output_infer,
                                          ref_file=None,
                                          metrics=hparams.metrics,
                                          bpe_delimiter=hparams.bpe_delimiter,
                                          beam_width=hparams.beam_width,
                                          tgt_sos=hparams.sos,
                                          tgt_eos=hparams.eos)
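
The caller of Example No. 1 is expected to build infer_model and pick a checkpoint before invoking the function. A minimal sketch of that setup, assuming the usual tensorflow/nmt helpers (model_helper.create_infer_model, a model class such as nmt_model.Model) with their standard signatures; the wrapper name run_single_worker is illustrative:

def run_single_worker(hparams, inference_input_file, inference_output_file):
    # Build the inference graph, iterator, and placeholders that
    # single_worker_inference expects to find on infer_model.
    infer_model = model_helper.create_infer_model(nmt_model.Model, hparams)

    # Decode from the most recent checkpoint in the training output directory.
    ckpt = tf.train.latest_checkpoint(hparams.out_dir)

    single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams)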
Example No. 2
def single_worker_inference(sess, infer_model, loaded_infer_model,
                            inference_input_file, inference_output_file,
                            hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    print("infer_data:", infer_data)

    with infer_model.graph.as_default():
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_translations_per_input=hparams.num_translations_per_input,
                infer_mode=hparams.infer_mode)
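
Unlike Example No. 1, this variant expects the caller to own the session and the already-restored model. A short sketch of that setup, reusing the same helpers Example No. 1 calls (utils.get_config_proto, model_helper.load_model); the wrapper name is illustrative:

def run_single_worker_with_session(infer_model, ckpt, inference_input_file,
                                   inference_output_file, hparams):
    # Open a session on the inference graph and restore the checkpoint into it.
    sess = tf.Session(graph=infer_model.graph, config=utils.get_config_proto())
    loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                 sess, "infer")
    # Hand both to the inference routine, which only initializes the iterator
    # and decodes.
    single_worker_inference(sess, infer_model, loaded_infer_model,
                            inference_input_file, inference_output_file,
                            hparams)
    sess.close()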
Example No. 3
def _external_eval(model, global_step, sess, hparams, iterator,
                   iterator_feed_dict, tgt_file, lbl_file, label,
                   summary_writer, save_on_best):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = hparams.out_dir
    decode = global_step > 0

    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    slot_output = os.path.join(out_dir, "slot_output_%s" % label)
    intent_output = os.path.join(out_dir, "intent_output_%s" % label)
    scores = nmt_utils.decode_and_evaluate(
        label,
        model,
        sess,
        slot_output,
        intent_output,
        ref_file=tgt_file,
        ref_lbl_file=lbl_file,
        metrics=hparams.metrics,
        subword_option=hparams.subword_option,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        task=hparams.task,
        decode=decode,
        infer_mode=hparams.infer_mode)
    # Save on best metrics
    if decode:
        for metric in hparams.metrics:
            best_metric_label = "best_" + metric

            utils.add_summary(summary_writer, global_step,
                              "%s_%s" % (label, metric), scores[metric])
            # metric: larger is better
            if save_on_best and scores[metric] > getattr(
                    hparams, best_metric_label):
                setattr(hparams, best_metric_label, scores[metric])
                model.saver.save(sess,
                                 os.path.join(
                                     getattr(hparams,
                                             best_metric_label + "_dir"),
                                     "translate.ckpt"),
                                 global_step=model.global_step)
        utils.save_hparams(out_dir, hparams)
    return scores
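
The save-on-best branch above assumes hparams already carries a best_<metric> score and a best_<metric>_dir checkpoint directory for every entry in hparams.metrics. A minimal sketch of that initialization, assuming hparams is a tf.contrib.training.HParams object as in tensorflow/nmt (attribute names are taken from the getattr/setattr calls above; the helper name is illustrative):

def init_best_metric_hparams(hparams):
    # One "best score so far" value and one checkpoint directory per metric,
    # matching the getattr/setattr lookups in _external_eval.
    for metric in hparams.metrics:
        hparams.add_hparam("best_" + metric, 0.0)
        best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric)
        hparams.add_hparam("best_" + metric + "_dir", best_metric_dir)
        tf.gfile.MakeDirs(best_metric_dir)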
Example No. 4
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_output_file, hparams, num_workers, jobid):
    """Inference using multiple workers."""
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, jobid)
    output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    # Split data to multiple workers
    total_load = len(infer_data)
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = jobid * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(
            infer_model.iterator.initializer, {
                infer_model.src_placeholder: infer_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            })
        # Decode
        utils.print_out("# Start decoding")
        nmt_utils.decode_and_evaluate("infer",
                                      loaded_infer_model,
                                      sess,
                                      output_infer,
                                      ref_file=None,
                                      metrics=hparams.metrics,
                                      bpe_delimiter=hparams.bpe_delimiter,
                                      beam_width=hparams.beam_width,
                                      tgt_eos=hparams.eos)

        # Change file name to indicate the file writing is completed.
        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        # Job 0 is responsible for the clean up.
        if jobid != 0: return

        # Now write all translations
        with codecs.getwriter("utf-8")(tf.gfile.GFile(final_output_infer,
                                                      mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waitting job %d to complete." %
                                    worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(tf.gfile.GFile(
                        worker_infer_done, mode="rb")) as f:
                    for translation in f:
                        final_f.write("%s" % translation)
                tf.gfile.Remove(worker_infer_done)
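
Every worker process runs this function over the same input with its own jobid; worker 0 additionally waits for the _done_<id> files and concatenates them into the final output. A sketch of the per-process dispatch, mirroring how tensorflow/nmt chooses between the single- and multi-worker paths (the wrapper name is illustrative):

def run_inference_worker(infer_model, ckpt, inference_input_file,
                         inference_output_file, hparams, num_workers, jobid):
    if num_workers > 1:
        # Each worker decodes its own shard; jobid 0 also merges all shards.
        multi_worker_inference(infer_model, ckpt, inference_input_file,
                               inference_output_file, hparams,
                               num_workers=num_workers, jobid=jobid)
    else:
        single_worker_inference(infer_model, ckpt, inference_input_file,
                                inference_output_file, hparams)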
Example No. 5
# Get input and output tensors/ops for inference.
src_vocab_placeholder = graph.get_tensor_by_name('source_vocab_file:0')
tgt_vocab_placeholder = graph.get_tensor_by_name('target_vocab_file:0')
src_data_placeholder = graph.get_tensor_by_name('source_data:0')
batch_size_placeholder = graph.get_tensor_by_name('batch_size:0')

tables_initializer = graph.get_operation_by_name('init_all_tables')
iterator_initializer = graph.get_operation_by_name('MakeIterator')
sample_words_tensor = graph.get_tensor_by_name(
    'hash_table_Lookup_1/LookupTableFindV2:0')

# Create a session with the imported graph.
config_proto = tf.compat.v1.ConfigProto(
    allow_soft_placement=True,
    intra_op_parallelism_threads=args.num_intra_threads,
    inter_op_parallelism_threads=args.num_inter_threads)
sess = tf.compat.v1.Session(graph=graph, config=config_proto)

# Read source data.
src_data = read_source_sentences(args.inference_input_file)

# Initialize vocabulary tables and source data iterator.
sess.run(tables_initializer, feed_dict={
    src_vocab_placeholder: create_new_vocab_file(args.src_vocab_file),
    tgt_vocab_placeholder: create_new_vocab_file(args.tgt_vocab_file)})
sess.run(iterator_initializer, feed_dict={
    src_data_placeholder: src_data,
    batch_size_placeholder: args.batch_size})

# Decode
decode_and_evaluate(args.run, sess, sample_words_tensor,
                    args.inference_output_file, args.inference_ref_file)
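
Example No. 5 assumes graph was imported earlier from a frozen GraphDef, and that args plus the read_source_sentences, create_new_vocab_file, and decode_and_evaluate helpers are defined in the same script. A minimal sketch of the graph-loading step that would precede it (the frozen_graph_file flag name is illustrative):

import tensorflow as tf

# Parse the frozen inference graph; the tensor/op names queried above
# ('source_data:0', 'init_all_tables', ...) must exist in this GraphDef.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(args.frozen_graph_file, 'rb') as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')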