# Example #1 (예제 #1)
def get_trace_summary(vocab_i2s,
                      pred_tokens, tgt_tokens,
                      src_words, inserted_words, deleted_words,
                      pred_len, tgt_len):
    """Build a per-example debug trace as concatenated string columns.

    Columns, in order: source text, inserted words, deleted words, target
    text, and prediction. Predictions with an extra beam dimension
    (rank > 2) are joined per beam; otherwise as one token sequence.
    """

    def _lookup_and_join(word_ids, *sep):
        # Map ids to vocab strings and join up to each row's true length.
        # `sep` is forwarded as-is so omitting it keeps join_tokens'
        # default separator, exactly like the original direct calls.
        return metrics.join_tokens(
            vocab_i2s.lookup(word_ids),
            length_pre_embedding(word_ids),
            *sep)

    # Beam-search outputs carry an extra beam dimension.
    has_beams = pred_tokens.shape.ndims > 2
    pred_joined = (metrics.join_beams(pred_tokens, pred_len)
                   if has_beams
                   else metrics.join_tokens(pred_tokens, pred_len))

    tgt_joined = metrics.join_tokens(tgt_tokens, tgt_len)
    src_joined = _lookup_and_join(src_words)
    iw_joined = _lookup_and_join(inserted_words, ', ')
    dw_joined = _lookup_and_join(deleted_words, ', ')

    return tf.concat([src_joined, iw_joined, dw_joined, tgt_joined, pred_joined], axis=1)
def test_join_beam_tokens():
    """Smoke-test join_tokens and join_beams on random padded beam data."""

    def generate_beam(min_val=10,
                      max_val=50,
                      useless_val=88,
                      min_len=5,
                      max_len=10,
                      beam_size=2):
        """Create one beam of `beam_size` random byte-string sequences.

        Each sequence has a random true length in [min_len, max_len] and is
        padded to max_len with the `useless_val` sentinel, which the join ops
        are expected to ignore. Returns (time-major token grid, per-beam
        true lengths).
        """
        seqs = []
        lengths = []
        for _ in range(beam_size):
            seq_len = random.randint(min_len, max_len)
            seq = [
                bytes(str(random.randint(min_val, max_val)), encoding='utf8')
                for _ in range(seq_len)
            ]
            # Pad beyond the true length with the sentinel value.
            seq += [bytes(str(useless_val), encoding='utf8')
                    ] * (max_len - seq_len)

            seqs.append(seq)
            lengths.append(seq_len)

        # Transpose to time-major: one row per timestep across beams.
        seqs = list(zip(*seqs))

        return seqs, lengths

    batch_size = 3

    with tf.Graph().as_default():
        tokens, lengths = zip(*[generate_beam() for _ in range(batch_size)])
        tokens = np.array(tokens)
        tokens = tf.constant(tokens, dtype=tf.string)
        lengths = tf.constant(lengths)

        # BUG FIX: the original assigned join_tokens(...) to `o` and then
        # immediately overwrote it with join_beams(...), leaving the
        # join_tokens path a dead store that was never executed. Run and
        # print both ops so each code path is actually exercised.
        joined_tokens = join_tokens(tokens, lengths)
        joined_beams = join_beams(tokens, lengths)

        with tf.Session() as sess:
            print(sess.run([joined_tokens, joined_beams]))