Example #1
def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix, inference_indices,
                              tgt_eos, subword_option):
    """Decoding only a specific set of sentences."""
    utils.print_out("  decoding to output %s , num sents %d." %
                    (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(tf.gfile.GFile(output_infer,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            nmt_outputs, infer_summary = model.decode(sess)

            # get text translation
            assert nmt_outputs.shape[0] == 1
            translation = nmt_utils.get_translation(
                nmt_outputs,
                sent_id=0,
                tgt_eos=tgt_eos,
                subword_option=subword_option)

            if infer_summary is not None:  # Attention models
                image_file = output_infer_summary_prefix + str(
                    decode_id) + ".png"
                utils.print_out("  save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                with tf.gfile.GFile(image_file, mode="w") as img_f:
                    img_f.write(image_summ.value[0].image.encoded_image_string)
            # translation is a byte string; decode it before writing text output.
            trans_f.write((translation + b"\n").decode("utf-8"))
            utils.print_out(translation + b"\n")
    utils.print_time("  done", start_time)
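These examples all funnel the raw decoder output through nmt_utils.get_translation. As a rough mental model only (a minimal sketch inferred from the call sites on this page, not the actual nmt_utils code), the helper picks one sentence out of the decoded batch, truncates it at the EOS token, and joins the remaining tokens into a byte string:

# Minimal sketch of the behaviour the callers above rely on; this is NOT the
# real nmt_utils.get_translation, and the subword handling is simplified away.
def get_translation_sketch(nmt_outputs, sent_id, tgt_eos, subword_option=None):
    """Pick one decoded sentence, cut it at EOS, and join its tokens."""
    if tgt_eos:
        tgt_eos = tgt_eos.encode("utf-8")
    output = nmt_outputs[sent_id, :].tolist()    # token byte strings for one sentence
    if tgt_eos and tgt_eos in output:
        output = output[:output.index(tgt_eos)]  # drop EOS and everything after it
    # A real implementation would merge BPE/SPM subword pieces here.
    return b" ".join(output)

This is why the callers pass sent_id=0 when the batch size is 1, and why the result is handled as bytes (decoded or concatenated with b"\n") before being written out.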
Example #2
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode."""
    decode_id = random.randint(0, len(src_data) - 1)
    utils.print_out("  # %d" % decode_id)

    iterator_feed_dict = {
        iterator_src_placeholder: [src_data[decode_id]],
        iterator_batch_size_placeholder: 1,
    }
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    infer_logits, nmt_outputs, attention_summary = model.decode(sess)
    if hparams.beam_width > 0:
        # get the top translation.
        nmt_outputs = nmt_outputs[0]

    translation = nmt_utils.get_translation(
        nmt_outputs,
        infer_logits,
        sent_id=0,
        tgt_eos=hparams.eos,
        subword_option=hparams.subword_option)

    utils.print_out("    src: %s" % src_data[decode_id])
    utils.print_out("    ref: %s" % tgt_data[decode_id])
    utils.print_out(b"    nmt: " + translation)

    # Summary
    if attention_summary is not None:
        summary_writer.add_summary(attention_summary, global_step)
Example #3
    def get_answer(self, question, id_spk):
        infer_data = [clean_text(question)]
        infer_model = self.infer_model
        with tf.Session(
                graph=infer_model.graph,
                config=utils.get_config_proto()) as sess:
            loaded_infer_model = model_helper.load_model(
                    infer_model.model, self.ckpt, sess, "infer")
            sess.run(
                    infer_model.iterator.initializer,
                    feed_dict={
                        infer_model.src_placeholder: infer_data,
                        infer_model.batch_size_placeholder: 1,
                        infer_model.src_speaker_placeholder: id_spk,
                        infer_model.tgt_speaker_placeholder: self.id
                        })

            nmt_outputs, _ = loaded_infer_model.decode(sess)
            translation = []
            for beam_id in range(self.num_translations_per_input):

                # Set sent_id to 0 because the batch size is 1.
                translation.append(nmt_utils.get_translation(
                        nmt_outputs=nmt_outputs[beam_id],
                        sent_id=0,
                        tgt_eos=self.hparams.eos,
                        subword_option=self.hparams.subword_option))

        return translation
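A hypothetical call of the method above, assuming the surrounding chatbot object is available as bot and that speaker ids are plain integers (both of these are assumptions, not part of the original example):

# Hypothetical usage; `bot`, the speaker id and the question text are made up.
replies = bot.get_answer("how are you today?", id_spk=3)
for reply in replies:              # one entry per requested beam
    print(reply.decode("utf-8"))   # get_translation returns byte strings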
Example #4
def get_metric(hparams, predictions, current_step):
    """Run inference and compute metric."""
    predicted_ids = []
    for prediction in predictions:
        predicted_ids.append(prediction["predictions"])

    mlperf_log.gnmt_print(key=mlperf_log.EVAL_SIZE,
                          value=hparams.examples_to_infer)
    if hparams.examples_to_infer < len(predicted_ids):
        predicted_ids = predicted_ids[0:hparams.examples_to_infer]
    translations = _convert_ids_to_strings(hparams.tgt_vocab_file,
                                           predicted_ids)

    trans_file = os.path.join(
        hparams.out_dir, "newstest2014_out_{}.tok.de".format(current_step))
    trans_dir = os.path.dirname(trans_file)
    if not tf.gfile.Exists(trans_dir):
        tf.gfile.MakeDirs(trans_dir)
    tf.logging.info("Writing to file %s" % trans_file)
    with codecs.getwriter("utf-8")(tf.gfile.GFile(trans_file,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for translation in translations:
            sentence = nmt_utils.get_translation(
                translation,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
            trans_f.write((sentence + b"\n").decode("utf-8"))

    # Evaluation
    output_dir = os.path.join(hparams.out_dir,
                              "eval_{}".format(hparams.test_year))
    tf.gfile.MakeDirs(output_dir)
    summary_writer = tf.summary.FileWriter(output_dir)

    ref_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)

    metric = "bleu"
    if hparams.use_borg:
        score = evaluation_utils.evaluate(ref_file, trans_file, metric,
                                          hparams.subword_option)
    else:
        score = get_sacrebleu(trans_file, hparams.detokenizer_file,
                              hparams.test_year)
    with tf.Graph().as_default():
        summaries = []
        summaries.append(tf.Summary.Value(tag=metric, simple_value=score))
    tf_summary = tf.Summary(value=list(summaries))
    summary_writer.add_summary(tf_summary, current_step)

    with tf.gfile.Open(os.path.join(output_dir, 'bleu'), 'w') as f:
        f.write('{}\n'.format(score))

    misc_utils.print_out("  %s: %.1f" % (metric, score))

    summary_writer.close()
    return score
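get_sacrebleu is a helper defined elsewhere in these repositories and is not shown on this page. Purely as an illustration of what such a helper could look like (the function name, its parameters, and the detokenizer invocation are assumptions; subprocess.run and sacrebleu.corpus_bleu are real APIs), one way to detokenize and score the output is:

import subprocess

import sacrebleu  # pip install sacrebleu


def get_sacrebleu_sketch(trans_file, detokenizer_file, ref_file, lang="de"):
    """Sketch: detokenize the tokenized hypotheses, then score with sacreBLEU."""
    detok_file = trans_file + ".detok"
    # Run a Moses-style detokenizer script over the tokenized translations.
    with open(trans_file, "rb") as fin, open(detok_file, "wb") as fout:
        subprocess.run(["perl", detokenizer_file, "-l", lang],
                       stdin=fin, stdout=fout, check=True)
    with open(detok_file, encoding="utf-8") as f:
        hyps = [line.rstrip("\n") for line in f]
    with open(ref_file, encoding="utf-8") as f:
        refs = [line.rstrip("\n") for line in f]
    # corpus_bleu takes the hypotheses and a list of reference streams.
    return sacrebleu.corpus_bleu(hyps, [refs]).score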
Example #5
def mytest_infer_interator():
    src_dataset = tf.data.TextLineDataset(hparam.train_src)
    myinput = get_infer_iterator(src_dataset, hparam)
    ss = myinput.reverse_table.lookup(myinput.src)
    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        sess.run(myinput.initializer)
        for i in range(5):
            try:
                _src, _src_seq_len, cc = sess.run(
                    [myinput.src, myinput.src_seq_len] + [ss])
                print('src', _src)

                print('src_seq_len', _src_seq_len)
                print('reverse')
                for i, c in enumerate(cc):
                    print(get_translation(cc, i, hparam.EOS))
            except tf.errors.OutOfRangeError:
                print('xxxxxxxxxxxxxxx')
                sess.run(myinput.initializer)
Example #6
    def translate(self, sentence_id_list, tpu_id=0):
        nmt_outputs = self.backends[tpu_id].predict(
            np.take(self.sources, sentence_id_list, 0),
            np.take(self.seq_lens, sentence_id_list, 0))[0]

        batch_size = nmt_outputs.shape[0]

        translation = []
        for decoded_id in range(batch_size):
            translation += [
                nmt_utils.get_translation(
                    nmt_outputs,
                    decoded_id,
                    tgt_eos=self.hparams.eos,
                    subword_option=self.hparams.subword_option)
            ]

        # Keeping track of how many translations happened
        self.count += len(translation)

        return translation
Example #7
    def infer(self, sess):
        '''
        Translate the input and return the translation result.

        :param sess:
        :return:
        '''
        sess.run(self._batchInput.initializer)
        translation = []
        while True:
            try:
                _translation = sess.run(self._result)
                if self._isBeam:
                    # Select the highest-scoring beam search hypothesis.
                    _translation = _translation[:, :, 0]

                for i, _ in enumerate(_translation):
                    translation.append(
                        get_translation(_translation, i, self.EOS,
                                        self._subword))

            except tf.errors.OutOfRangeError:
                break
        return InferOutput(translation)
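The slicing _translation[:, :, 0] above assumes the decoder returns an array whose last axis is the beam dimension, so index 0 keeps the highest-scoring hypothesis for every sentence and time step. A tiny self-contained illustration of that slicing (the shapes here are made up):

import numpy as np

# Fake decoder output of shape [batch, time, beam]; beam index 0 is the best one.
decoded = np.arange(2 * 3 * 4).reshape(2, 3, 4)
top_beam = decoded[:, :, 0]   # shape [batch, time]: one token id per time step
print(top_beam.shape)         # (2, 3)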
Example #8
def get_metrics(hparams, model_fn, ckpt=None, only_translate=False):
    """Run inference and compute metrics."""
    pred_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                            model_dir=hparams.output_dir)

    benchmark_hook = BenchmarkHook(hparams.infer_batch_size)

    predictions = pred_estimator.predict(make_input_fn(
        hparams, tf.contrib.learn.ModeKeys.INFER),
                                         checkpoint_path=ckpt,
                                         hooks=[benchmark_hook])
    translations = []
    output_tokens = []
    beam_id = 0
    for prediction in predictions:
        # get the top translation.
        if beam_id == 0:
            for sent_id in range(hparams.infer_batch_size):
                if sent_id >= prediction["predictions"].shape[0]:
                    break
                trans, output_length = nmt_utils.get_translation(
                    prediction["predictions"],
                    sent_id=sent_id,
                    tgt_eos=hparams.eos,
                    subword_option=hparams.subword_option)
                translations.append(trans)
                output_tokens.append(output_length)
        beam_id += 1
        if beam_id == hparams.beam_width:
            beam_id = 0

    if only_translate:
        trans_file = hparams.translate_file + '.trans.tok'
    else:
        trans_file = os.path.join(
            hparams.output_dir, "newstest2014_out_{}.tok.de".format(
                pred_estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)))
    trans_dir = os.path.dirname(trans_file)
    if not tf.gfile.Exists(trans_dir):
        tf.gfile.MakeDirs(trans_dir)
    tf.logging.info("Writing to file %s" % trans_file)
    with codecs.getwriter("utf-8")(tf.gfile.GFile(trans_file,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for translation in translations:
            trans_f.write((translation + b"\n").decode("utf-8"))

    if only_translate:
        return None, benchmark_hook.get_average_speed_and_latencies(), sum(
            output_tokens)

    # Evaluation
    output_dir = os.path.join(pred_estimator.model_dir, "eval")
    tf.gfile.MakeDirs(output_dir)
    summary_writer = tf.summary.FileWriter(output_dir)

    ref_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
    # Hardcoded.
    metric = "bleu"
    score = get_sacrebleu(trans_file, hparams.detokenizer_file)

    misc_utils.print_out("bleu is %.5f" % score)
    with tf.Graph().as_default():
        summaries = []
        summaries.append(tf.Summary.Value(tag=metric, simple_value=score))
    tf_summary = tf.Summary(value=list(summaries))
    summary_writer.add_summary(
        tf_summary,
        pred_estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))

    summary_writer.close()
    return pred_estimator.get_variable_value(
        tf.GraphKeys.GLOBAL_STEP
    ), score, benchmark_hook.get_average_speed_and_latencies(), sum(
        output_tokens)
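In get_metrics above, the beam_id counter keeps one prediction out of every hparams.beam_width emitted by the estimator, i.e. only the top beam. Assuming the predictions really are produced beam-major as the loop implies, the same selection can be written as a filter (a sketch, not part of the original code):

# Equivalent top-beam filter for the loop above; assumes beam-major ordering
# and hparams.beam_width >= 1. `predictions` is the estimator's generator.
top_beam_predictions = (p for i, p in enumerate(predictions)
                        if i % hparams.beam_width == 0)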
Example #9
def get_metrics(hparams, model_fn, ckpt=None):
    """Run inference and compute metrics."""
    pred_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                            model_dir=hparams.output_dir)

    # It's in the dataset.
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_SIZE, value=3003)

    predictions = pred_estimator.predict(make_input_fn(
        hparams, tf.contrib.learn.ModeKeys.INFER),
                                         checkpoint_path=ckpt)
    translations = []
    beam_id = 0
    for prediction in predictions:
        # get the top translation.
        if beam_id == 0:
            for sent_id in range(hparams.infer_batch_size):
                if sent_id >= prediction["predictions"].shape[0]:
                    break
                trans = nmt_utils.get_translation(
                    prediction["predictions"],
                    sent_id=sent_id,
                    tgt_eos=hparams.eos,
                    subword_option=hparams.subword_option)
                translations.append(trans)
        beam_id += 1
        if beam_id == hparams.beam_width:
            beam_id = 0

    trans_file = os.path.join(
        hparams.output_dir, "newstest2014_out_{}.tok.de".format(
            pred_estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)))
    trans_dir = os.path.dirname(trans_file)
    if not tf.gfile.Exists(trans_dir):
        tf.gfile.MakeDirs(trans_dir)
    tf.logging.info("Writing to file %s" % trans_file)
    with codecs.getwriter("utf-8")(tf.gfile.GFile(trans_file,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for translation in translations:
            trans_f.write((translation + b"\n").decode("utf-8"))

    # Evaluation
    output_dir = os.path.join(pred_estimator.model_dir, "eval")
    tf.gfile.MakeDirs(output_dir)
    summary_writer = tf.summary.FileWriter(output_dir)

    ref_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
    # Hardcoded.
    metric = "bleu"
    if hparams.use_borg:
        score = evaluation_utils.evaluate(ref_file, trans_file, metric,
                                          hparams.subword_option)
    else:
        score = get_sacrebleu(trans_file, hparams.detokenizer_file)

    misc_utils.print_out("bleu is %.2f" % score)
    with tf.Graph().as_default():
        summaries = []
        summaries.append(tf.Summary.Value(tag=metric, simple_value=score))
    tf_summary = tf.Summary(value=list(summaries))
    summary_writer.add_summary(
        tf_summary,
        pred_estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))

    misc_utils.print_out("  %s: %.1f" % (metric, score))

    summary_writer.close()
    return score