# Example 1
def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix, inference_indices,
                              tgt_sos, tgt_eos, bpe_delimiter):
    """Decode only a specific set of sentences and write them to a file.

    Args:
        model: inference model exposing `decode(sess)`.
        sess: TensorFlow session the model runs in.
        output_infer: path of the UTF-8 text file translations are written to.
        output_infer_summary_prefix: path prefix for per-sentence attention
            images (attention models only).
        inference_indices: ids of the sentences to decode; one `decode` call
            is made per id.
        tgt_sos: target start-of-sentence marker.
        tgt_eos: target end-of-sentence marker.
        bpe_delimiter: BPE subword delimiter used to undo subword splitting.
    """
    utils.print_out("  decoding to output %s , num sents %d." %
                    (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(tf.gfile.GFile(output_infer,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            nmt_outputs, infer_summary = model.decode(sess)

            # Get text translation; exactly one sentence per decode step.
            assert nmt_outputs.shape[0] == 1
            translation = nmt_utils.get_translation(
                nmt_outputs,
                sent_id=0,
                tgt_sos=tgt_sos,
                tgt_eos=tgt_eos,
                bpe_delimiter=bpe_delimiter)

            if infer_summary is not None:  # Attention models
                image_file = output_infer_summary_prefix + str(
                    decode_id) + ".png"
                utils.print_out("  save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                # The attention image is raw PNG bytes; open in binary mode
                # so the data is not mangled by newline translation.
                with tf.gfile.GFile(image_file, mode="wb") as img_f:
                    img_f.write(image_summ.value[0].image.encoded_image_string)

            # `translation` is bytes (the print_out below uses a bytes format
            # string); decode it so the utf-8 text writer receives str rather
            # than the repr "b'...'".
            trans_f.write((translation + b"\n").decode("utf-8"))
            utils.print_out(b"%s\n" % translation)
    utils.print_time("  done", start_time)
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode."""
    # Choose one source sentence uniformly at random.
    sample_id = random.randint(0, len(src_data) - 1)
    utils.print_out("  # %d" % sample_id)

    # Re-initialize the inference iterator on just the sampled sentence.
    feed = {
        iterator_src_placeholder: [src_data[sample_id]],
        iterator_batch_size_placeholder: 1,
    }
    sess.run(iterator.initializer, feed_dict=feed)

    nmt_outputs, attention_summary = model.decode(sess)

    # With beam search, the first beam carries the best hypothesis.
    if hparams.beam_width > 0:
        nmt_outputs = nmt_outputs[0]

    translation = nmt_utils.get_translation(
        nmt_outputs,
        sent_id=0,
        tgt_eos=hparams.eos,
        subword_option=hparams.subword_option)

    # Record the attention summary when the model produced one.
    if attention_summary is not None:
        summary_writer.add_summary(attention_summary, global_step)
# Example 3
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode."""
    batch = hparams.infer_batch_size
    # Decode the last `batch` sentences of the data set in one shot.
    sess.run(
        iterator.initializer,
        feed_dict={
            iterator_src_placeholder: src_data[-batch:],
            iterator_batch_size_placeholder: batch,
        })

    nmt_outputs, att_w_history, ext_w_history = model.decode(sess)

    # With beam search, the first beam carries the best hypotheses.
    if hparams.beam_width > 0:
        nmt_outputs = nmt_outputs[0]

    nmt_outputs = np.asarray(nmt_outputs)

    outputs = []
    for idx in range(batch):
        record = {}
        translation = nmt_utils.get_translation(
            nmt_outputs,
            sent_id=idx,
            tgt_sos=hparams.sos,
            tgt_eos=hparams.eos,
            bpe_delimiter=hparams.bpe_delimiter)
        pos = -batch + idx  # index into the tail slice of the data
        # Only echo the first few examples to the log.
        if idx <= 5:
            utils.print_out("    src: %s" % src_data[pos])
            utils.print_out("    ref: %s" % tgt_data[pos])
            utils.print_out(b"    nmt: %s" % translation)
        record['src'] = src_data[pos]
        record['ref'] = tgt_data[pos]
        record['nmt'] = translation
        if att_w_history is not None:
            record['attention_head'] = att_w_history[pos]
        if ext_w_history is not None:
            for head_idx, ext_head in enumerate(ext_w_history):
                record['ext_head_{0}'.format(head_idx)] = ext_head[pos]
        outputs.append(record)

    # Optionally persist all attention-head histories for this global step.
    if hparams.record_w_history:
        path = hparams.out_dir + '/heads_step_{0}.pickle'.format(global_step)
        with open(path, 'wb') as f:
            if len(outputs) > 0:
                pickle.dump(outputs, f)
# Example 4
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, lbl_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode (slot tagging and/or intent prediction).

    Args:
        model: inference model exposing `decode(sess)`.
        global_step: training step used when writing the attention summary.
        sess: TensorFlow session the model runs in.
        hparams: hyperparameters; reads `task`, `infer_mode`, `eos`,
            `subword_option`.
        iterator: inference iterator re-initialized on the sampled sentence.
        src_data / tgt_data / lbl_data: source sentences, reference slot
            sequences, and reference intent labels.
        iterator_src_placeholder / iterator_batch_size_placeholder: feed
            placeholders for the iterator.
        summary_writer: writer for the optional attention summary.

    Raises:
        ValueError: if `hparams.task` is neither "joint" nor "intent".
    """
    decode_id = random.randint(0, len(src_data) - 1)
    utils.print_out("  # %d" % decode_id)

    iterator_feed_dict = {
        iterator_src_placeholder: [src_data[decode_id]],
        iterator_batch_size_placeholder: 1,
    }
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    if hparams.task == "joint":
        outputs, intent_pred, src_seq_length, attention_summary = model.decode(
            sess)
        # Only the joint task produces slot outputs; with beam search the
        # first beam carries the best hypothesis. (Previously this check ran
        # outside the branch and raised NameError on `outputs` for the
        # intent-only task.)
        if hparams.infer_mode == "beam_search":
            outputs = outputs[0]
    elif hparams.task == "intent":
        intent_pred, attention_summary = model.decode(sess)
    else:
        # Fail loudly instead of hitting an UnboundLocalError below.
        raise ValueError("Unknown task: %s" % hparams.task)

    utils.print_out("          intent: %s" % lbl_data[decode_id])
    utils.print_out("             src: %s" % src_data[decode_id])
    if hparams.task == "joint":
        translation = nmt_utils.get_translation(
            outputs,
            src_seq_length,
            sent_id=0,
            tgt_eos=hparams.eos,
            subword_option=hparams.subword_option)
        utils.print_out("      slot (ref): %s" % tgt_data[decode_id])
        utils.print_out(b"   intent (pred): %s" % intent_pred[0])
        utils.print_out(b"     slot (pred): %s\n" % translation)
    elif hparams.task == "intent":
        utils.print_out(b"   intent (pred): %s" % intent_pred[0])

    # Summary
    if attention_summary is not None:
        summary_writer.add_summary(attention_summary, global_step)
def decode(ckpt_path, inference_input_file, inference_output_file, hparams):
    """Restore a model from `ckpt_path` and decode sentences from a file.

    Args:
        ckpt_path: checkpoint to load the inference model from.
        inference_input_file: path to the source sentences to decode.
        inference_output_file: unused here; kept for interface compatibility.
        hparams: hyperparameters; reads `task`, `infer_mode`, `eos`,
            `subword_option`.

    Raises:
        ValueError: if `hparams.task` is neither "joint" nor "intent".
    """
    model_creator = get_model_creator(hparams)
    infer_model = model_helper.create_infer_model(model_creator, hparams, None)
    sess, loaded_infer_model = start_sess_and_load_model(
        infer_model, ckpt_path)

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    print("infer_data:", infer_data)

    for i in range(len(infer_data)):
        # Decode one sentence at a time.
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: [infer_data[i]],
                     infer_model.batch_size_placeholder: 1
                 })

        if hparams.task == "joint":
            outputs, intent_pred, src_seq_length, attention_summary = \
                loaded_infer_model.decode(sess)
            # Only the joint task produces slot outputs; with beam search the
            # first beam carries the best hypothesis. (Previously this check
            # ran outside the branch and raised NameError on `outputs` for
            # the intent-only task.)
            if hparams.infer_mode == "beam_search":
                outputs = outputs[0]
        elif hparams.task == "intent":
            intent_pred, attention_summary = loaded_infer_model.decode(sess)
        else:
            # Fail loudly instead of hitting an UnboundLocalError below.
            raise ValueError("Unknown task: %s" % hparams.task)

        utils.print_out("             src: %s" % infer_data[i])
        if hparams.task == "joint":
            translation = nmt_utils.get_translation(
                outputs,
                src_seq_length,
                sent_id=0,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
            utils.print_out(b"   intent (pred): %s" % intent_pred[0])
            utils.print_out(b"     slot (pred): %s\n" % translation)
        elif hparams.task == "intent":
            utils.print_out(b"   intent (pred): %s\n" % intent_pred[0])