예제 #1
0
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Decode one randomly chosen source sentence and print the result.

    A random index into ``src_data`` is drawn, ``iterator`` is re-initialized
    with only that sentence (batch size 1) and ``model.decode(sess)`` is run
    once. The source sentence, the entry of ``tgt_data`` at the same index and
    the translation are printed through ``utils.print_out``. When decoding
    returns a non-None summary it is handed to
    ``summary_writer.add_summary`` together with ``global_step``.
    """
    last_index = len(src_data) - 1
    sample_idx = random.randint(0, last_index)
    utils.print_out("  # %d" % sample_idx)

    # Feed exactly one sentence so the decode call below sees a batch of one.
    feed = {
        iterator_src_placeholder: [src_data[sample_idx]],
        iterator_batch_size_placeholder: 1,
    }
    sess.run(iterator.initializer, feed_dict=feed)

    outputs, summary = model.decode(sess)

    # In beam-search mode only the first entry of the outputs is kept.
    top_outputs = outputs[0] if hparams.infer_mode == "beam_search" else outputs

    decoded = nmt_utils.get_translation(
        top_outputs,
        sent_id=0,
        tgt_eos=hparams.eos,
        subword_option=hparams.subword_option)

    utils.print_out("    src: %s" % src_data[sample_idx])
    utils.print_out("    ref: %s" % tgt_data[sample_idx])
    utils.print_out(b"    nmt: " + decoded)

    # Nothing to log when decoding produced no summary.
    if summary is None:
        return
    summary_writer.add_summary(summary, global_step)
예제 #2
0
def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix, inference_indices,
                              tgt_eos, subword_option):
    """Decode only a specific set of sentences and write them to a file.

    One ``model.decode(sess)`` call is made per entry of
    ``inference_indices``. Each translation is appended as one UTF-8 line to
    ``output_infer`` and echoed through ``utils.print_out``. When decoding
    also returns a summary, its first image is saved as a PNG file.

    Args:
      model: Object whose ``decode(sess)`` returns
        ``(nmt_outputs, infer_summary)``.
      sess: Session passed through to ``model.decode``.
      output_infer: Path of the translation output file.
      output_infer_summary_prefix: Filename prefix for saved images; the
        image for id ``i`` is written to ``prefix + str(i) + ".png"``.
      inference_indices: Sentence ids. Here they only determine the number
        of decode calls and the image filenames.
        NOTE(review): presumably the caller has already initialized the
        iterator with exactly these sentences -- confirm.
      tgt_eos: Forwarded to ``nmt_utils.get_translation``.
      subword_option: Forwarded to ``nmt_utils.get_translation``.
    """
    utils.print_out("  decoding to output %s , num sents %d." %
                    (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(tf.gfile.GFile(output_infer,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            nmt_outputs, infer_summary = model.decode(sess)

            # get text translation
            assert nmt_outputs.shape[0] == 1
            translation = nmt_utils.get_translation(
                nmt_outputs,
                sent_id=0,
                tgt_eos=tgt_eos,
                subword_option=subword_option)

            if infer_summary is not None:  # Attention models
                image_file = output_infer_summary_prefix + str(
                    decode_id) + ".png"
                utils.print_out("  save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                # The payload is an encoded image, so open in binary mode.
                with tf.gfile.GFile(image_file, mode="wb") as img_f:
                    # pylint: disable=no-member
                    img_f.write(image_summ.value[0].image.encoded_image_string)

            # `translation` is a bytes object (it is concatenated with b"\n"
            # below). Formatting it with "%s" would write its repr
            # ("b'...'") on Python 3, so decode it to text before handing it
            # to the UTF-8 writer.
            line = translation + b"\n"
            trans_f.write(line.decode("utf-8"))
            utils.print_out(line)
    utils.print_time("  done", start_time)