Beispiel #1
0
def prepare_decoder_input_output(tgt_words, tgt_len, vocab_table):
    """Derive the decoder input/output sequences (and their lengths) from targets.

    The start, stop and pad token ids are resolved through ``vocab_table``;
    the start id is handed to ``decoder.prepare_decoder_inputs`` and the
    stop/pad ids to ``decoder.prepare_decoder_output``.

    Args:
        tgt_words: tensor of word ids, [batch x max_len]
        tgt_len: vector of sentence lengths, [batch]
        vocab_table: vocabulary lookup table passed to ``vocab.get_token_id``

    Returns:
        A 4-tuple ``(dec_input, dec_input_len, dec_output, dec_output_len)``:
        dec_input: tensor of word ids, [batch x max_len+1]
        dec_input_len: vector of sentence lengths, [batch]
        dec_output: tensor of word ids, [batch x max_len+1]
        dec_output_len: vector of sentence lengths, [batch]

    """
    special_tokens = (vocab.START_TOKEN, vocab.STOP_TOKEN, vocab.PAD_TOKEN)
    start_id, stop_id, pad_id = (
        vocab.get_token_id(token, vocab_table) for token in special_tokens
    )

    def with_length(word_ids):
        # Pair a word-id tensor with the lengths computed from it.
        return word_ids, seq.length_pre_embedding(word_ids)

    input_pair = with_length(
        decoder.prepare_decoder_inputs(tgt_words, start_id))
    output_pair = with_length(
        decoder.prepare_decoder_output(tgt_words, tgt_len, stop_id, pad_id))

    return input_pair + output_pair
def test_wtf():
    """Smoke-test the input pipeline and decoder preparation on the Yelp data.

    Builds a fresh graph that reads word embeddings from
    ``../data/word_vectors/glove.6B.300d_yelp.txt``, creates the vocabulary
    lookup tables, feeds ``../data/yelp_dataset_large_split/train.tsv``
    through ``input_fn`` and maps every id tensor (source, target, the two
    word lists, decoder input and decoder output) back to strings. The
    session then drains the iterator. Nothing is asserted on the values; the
    test only checks that every batch can be evaluated.
    """
    with tf.Graph().as_default():
        # NOTE(review): 300 and 10000 are presumably the embedding dimension
        # and a vocabulary size limit -- confirm against
        # vocab.read_word_embeddings. The embedding matrix is not needed here.
        V, _ = vocab.read_word_embeddings(
            Path('../data') / 'word_vectors' / 'glove.6B.300d_yelp.txt',
            300,
            10000
        )

        table = vocab.create_vocab_lookup_tables(V)
        vocab_s2i = table[vocab.STR_TO_INT]
        vocab_i2s = table[vocab.INT_TO_STR]

        # NOTE(review): 64 and 1 look like batch size and number of epochs --
        # confirm against input_fn.
        dataset = input_fn('../data/yelp_dataset_large_split/train.tsv', table, 64, 1)
        iterator = dataset.make_initializable_iterator()

        (src, tgt, iw, dw), _ = iterator.get_next()
        tgt_len = length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(
            tgt, vocab.get_token_id(vocab.START_TOKEN, vocab_s2i))
        dec_output = decoder.prepare_decoder_output(
            tgt, tgt_len,
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_s2i),
            vocab.get_token_id(vocab.PAD_TOKEN, vocab_s2i))

        # Map every id tensor back to its string form so that evaluating the
        # fetches exercises both lookup directions.
        fetches = [
            vocab_i2s.lookup(src),
            vocab_i2s.lookup(tgt),
            vocab_i2s.lookup(iw),
            vocab_i2s.lookup(dw),
            vocab_i2s.lookup(dec_output),
            vocab_i2s.lookup(dec_inputs),
        ]

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer()])
            sess.run(iterator.initializer)

            while True:
                try:
                    sess.run(fetches)
                except tf.errors.OutOfRangeError:
                    # The iterator is exhausted. Any other error is a real
                    # failure and must propagate instead of being swallowed.
                    break
def test_decoder_train(dataset_file, embedding_file):
    """Build the train and beam-search decoders end to end and run one batch.

    The training decoder and the attention scores are only constructed in the
    graph; the session evaluates (and prints) the beam-search decoder output
    for a single batch. No value is asserted.

    Args:
        dataset_file: pair whose first element is the dataset path handed to
            ``neural_editor.input_fn``; the second element is unused here.
        embedding_file: pair whose first element is the embeddings path handed
            to ``vocab.read_word_embeddings``; the second element is unused
            here.
    """
    with tf.Graph().as_default():
        d_fn, _ = dataset_file
        e_fn, _ = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        def lookup_token_id(token):
            # Encode the token as a utf8 string constant and look up its id.
            token_const = tf.constant(bytes(token, encoding='utf8'),
                                      dtype=tf.string)
            return vocab_lookup.lookup(token_const)

        stop_token_id = lookup_token_id(vocab.STOP_TOKEN)
        start_token_id = lookup_token_id(vocab.START_TOKEN)
        pad_token_id = lookup_token_id(vocab.PAD_TOKEN)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE,
                                         NUM_EPOCH)
        iterator = dataset.make_initializable_iterator()
        (src, tgt, inw, dlw), _ = iterator.get_next()

        src_len = sequence.length_pre_embedding(src)
        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)

        # The decoder outputs are not fetched below; they are still built so
        # that their graph construction is exercised.
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len,
                                                     stop_token_id,
                                                     pad_token_id)
        sequence.length_pre_embedding(dec_outputs)

        batch_size = tf.shape(src)[0]
        edit_vector = edit_encoder.random_noise_encoder(
            batch_size, EDIT_DIM, 14.0)

        embedding = tf.get_variable(
            'embeddings',
            shape=embed_matrix.shape,
            initializer=tf.constant_initializer(embed_matrix))

        src_embd = tf.nn.embedding_lookup(embedding, src)
        src_sent_embeds, enc_final_states = encoder.source_sent_encoder(
            src_embd, src_len, 20, 3, 0.8)

        agn = agenda.linear(enc_final_states, edit_vector, 4)

        # Both decoders take the same embedded word lists and lengths, so
        # build those tensors once and share them.
        inw_embd = tf.nn.embedding_lookup(embedding, inw)
        dlw_embd = tf.nn.embedding_lookup(embedding, dlw)
        inw_len = sequence.length_pre_embedding(inw)
        dlw_len = sequence.length_pre_embedding(dlw)

        dec_out = decoder.train_decoder(agn, embedding, dec_inputs,
                                        src_sent_embeds, inw_embd, dlw_embd,
                                        dec_inputs_len, src_len, inw_len,
                                        dlw_len, 5, 20, 3, False)

        eval_dec_out = decoder.beam_eval_decoder(
            agn, embedding, start_token_id, stop_token_id, src_sent_embeds,
            inw_embd, dlw_embd, src_len, inw_len, dlw_len, 5, 20, 3, 40)

        # Unpacking checks that the train decoder returns exactly three
        # elements; the names avoid shadowing the builtin ``len`` and the
        # encoder's final states.
        _attn_out, _dec_final_states, _dec_len = dec_out
        # Built but never fetched: only graph construction is checked.
        decoder.attention_score(dec_out)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(iterator.initializer)

            print(sess.run([eval_dec_out]))
def test_decoder_prepares(dataset_file, embedding_file):
    """Check decoder input/output preparation on every batch of the dataset.

    For each batch this asserts that:
      * decoder input and output lengths both equal the target length + 1,
      * every decoder input row starts with the start token id,
      * the last relevant decoder output token is the stop token id.

    Args:
        dataset_file: pair whose first element is the dataset path handed to
            ``neural_editor.input_fn``; the second element is unused here.
        embedding_file: pair whose first element is the embeddings path handed
            to ``vocab.read_word_embeddings``; the second element is unused
            here.
    """
    with tf.Graph().as_default():
        d_fn, _ = dataset_file
        e_fn, _ = embedding_file

        v, _ = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        def lookup_token_id(token):
            # Encode the token as a utf8 string constant and look up its id.
            token_const = tf.constant(bytes(token, encoding='utf8'),
                                      dtype=tf.string)
            return vocab_lookup.lookup(token_const)

        stop_token_id = lookup_token_id(vocab.STOP_TOKEN)
        start_token_id = lookup_token_id(vocab.START_TOKEN)
        pad_token_id = lookup_token_id(vocab.PAD_TOKEN)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE,
                                         NUM_EPOCH)
        iterator = dataset.make_initializable_iterator()
        (_, tgt, _, _), _ = iterator.get_next()

        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len,
                                                     stop_token_id,
                                                     pad_token_id)

        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        dec_outputs_last = sequence.last_relevant(
            tf.expand_dims(dec_outputs, 2), dec_outputs_len)
        # Flatten to one value per example. A bare tf.squeeze would also drop
        # the batch axis when a batch holds a single example, yielding a
        # scalar that cannot be compared element-wise below.
        dec_outputs_last = tf.reshape(dec_outputs_last, [-1])

        fetches = [
            dec_inputs, tgt_len, dec_inputs_len, dec_outputs_len,
            start_token_id, stop_token_id, dec_outputs_last
        ]

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(iterator.initializer)

            while True:
                # Fetched values get their own names: rebinding the tensor
                # names would make the next sess.run receive numpy arrays.
                try:
                    (dec_in, tgt_l, dil, dol, start_id, stop_id,
                     last_out) = sess.run(fetches)
                except tf.errors.OutOfRangeError:
                    # Dataset exhausted. Only this ends the loop; assertion
                    # failures and other errors must fail the test.
                    break

                assert list(dil) == list(dol) == list(tgt_l + 1)
                assert list(dec_in[:, 0]) == list(
                    np.ones_like(dec_in[:, 0]) * start_id)
                assert list(last_out) == list(
                    np.ones_like(last_out) * stop_id)