コード例 #1
0
def prepare_decoder_input_output(tgt_words, tgt_len, vocab_table):
    """Build the teacher-forcing input/output pair for the decoder.

    The decoder input is the target sequence prefixed with <start>; the
    decoder output is the target sequence suffixed with <stop> (and padded
    with <pad>), so both are one token longer than the target.

    Args:
        tgt_words: tensor of word ids, [batch x max_len]
        tgt_len: vector of sentence lengths, [batch]
        vocab_table: instance of tf.vocab_lookup_table

    Returns:
        dec_input: tensor of word ids, [batch x max_len+1]
        dec_input_len: vector of sentence lengths, [batch]
        dec_output: tensor of word ids, [batch x max_len+1]
        dec_output_len: vector of sentence lengths, [batch]
    """
    # Resolve the three special-token ids from the lookup table up front.
    start_id, stop_id, pad_id = (
        vocab.get_token_id(token, vocab_table)
        for token in (vocab.START_TOKEN, vocab.STOP_TOKEN, vocab.PAD_TOKEN)
    )

    # <start> w1 ... wn  -> what the decoder consumes at train time.
    dec_input = decoder.prepare_decoder_inputs(tgt_words, start_id)
    # w1 ... wn <stop>   -> what the decoder is trained to emit.
    dec_output = decoder.prepare_decoder_output(tgt_words, tgt_len,
                                                stop_id, pad_id)

    return (dec_input,
            seq.length_pre_embedding(dec_input),
            dec_output,
            seq.length_pre_embedding(dec_output))
コード例 #2
0
def test_wtf():
    """Smoke test: drive the input pipeline end to end and decode every
    batch back to strings via the reverse vocab table.

    Fixes over the previous version:
      * the drain loop now catches only ``tf.errors.OutOfRangeError``
        (dataset exhaustion) instead of a bare ``except:`` that silently
        swallowed real failures (lookup errors, shape mismatches, ...);
      * no longer shadows the builtin ``iter``;
      * unused length tensors (src/iw/dw) removed.
    """
    with tf.Graph().as_default():
        V, embed_matrix = vocab.read_word_embeddings(
            Path('../data') / 'word_vectors' / 'glove.6B.300d_yelp.txt',
            300,
            10000
        )

        table = vocab.create_vocab_lookup_tables(V)
        vocab_s2i = table[vocab.STR_TO_INT]
        vocab_i2s = table[vocab.INT_TO_STR]

        dataset = input_fn('../data/yelp_dataset_large_split/train.tsv', table, 64, 1)
        data_iter = dataset.make_initializable_iterator()

        (src, tgt, iw, dw), _ = data_iter.get_next()
        # Only tgt_len is needed below (to build the decoder output).
        tgt_len = length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(
            tgt, vocab.get_token_id(vocab.START_TOKEN, vocab_s2i))

        dec_output = decoder.prepare_decoder_output(
            tgt, tgt_len,
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_s2i),
            vocab.get_token_id(vocab.PAD_TOKEN, vocab_s2i))

        # Map ids back to strings so a human can eyeball the batches.
        t_src = vocab_i2s.lookup(src)
        t_tgt = vocab_i2s.lookup(tgt)
        t_iw = vocab_i2s.lookup(iw)
        t_dw = vocab_i2s.lookup(dw)

        t_do = vocab_i2s.lookup(dec_output)
        t_di = vocab_i2s.lookup(dec_inputs)

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer()])
            sess.run(data_iter.initializer)

            while True:
                try:
                    ts, tt, tiw, tdw, tdo, tdi = sess.run([t_src, t_tgt, t_iw, t_dw, t_do, t_di])
                except tf.errors.OutOfRangeError:
                    # Dataset exhausted: the pipeline ran cleanly end to end.
                    break
コード例 #3
0
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 recons_dense_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    """Build the neural-editor training graph.

    Encodes the base sentence, computes an edit vector from the
    insert/delete word sets, combines both into an agenda, and wires up
    a teacher-forced training decoder plus a greedy or beam inference
    decoder. Also adds an edit-vector reconstruction loss.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len)

    NOTE(review): ``micro_edit_ev_dim``, ``num_heads``, ``wa_hidden_dim``,
    ``wa_hidden_layer``, ``meve_hidden_dim`` and ``meve_hidden_layers`` are
    accepted but never used in this body — confirm whether they belong to
    another editor variant.
    """
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    # NOTE(review): output_len is taken from extended_output_words, not
    # output_words — presumably both have identical lengths; confirm.
    output_len = seq.length_pre_embedding(extended_output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    output_word_embeds = vocab.embed_tokens(output_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # This graph only supports the accumulator-encoder path below; the
    # kill/draw branches are kept for interface parity but are ruled out
    # here. NOTE(review): ``assert`` is stripped under ``python -O`` —
    # consider raising ValueError instead.
    assert kill_edit == False and draw_edit == False

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector = accumulator_encoder(insert_word_embeds,
                                              delete_word_embeds,
                                              iw_len,
                                              dw_len,
                                              edit_dim,
                                              lamb_reg,
                                              norm_eps,
                                              norm_max,
                                              dropout_keep,
                                              enable_vae=enable_vae)

            # Dummy word-alignment placeholders (real attention over
            # inserted/deleted words is not produced by this encoder).
            # NOTE(review): only defined on this branch — the decoder
            # calls below would raise NameError if kill_edit/draw_edit
            # were ever true; the assert above currently prevents that.
            wa_inserted, wa_deleted = (tf.constant([[0]]), tf.constant(
                [[0]])), (tf.constant([[0]]), tf.constant([[0]]))

    # [batch x agenda_dim]
    base_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    # NOTE(review): arguments here are (words, extended_words, length) —
    # confirm this matches prepare_decoder_input_output's signature; the
    # 3-arg helper elsewhere takes (tgt_words, tgt_len, vocab_table).
    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    # Extended-vocab decoder inputs; -1 marks the start position here.
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        wa_inserted,
        wa_deleted,
        train_dec_inp_len,
        base_len,
        src_len,
        tgt_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    # Inference decoder: beam search when requested, greedy otherwise.
    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    # Reconstruct the edit vector from the output sentence and penalize
    # the difference — regularizes the edit space.
    edit_vector_recons = output_words_to_edit_vector(
        output_word_embeds, output_len, edit_dim, ctx_hidden_dim,
        ctx_hidden_layer, recons_dense_layers, swap_memory)
    optimizer.add_reconst_loss(edit_vector, edit_vector_recons)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
コード例 #4
0
def test_decoder_train(dataset_file, embedding_file):
    """Exercise graph construction for the train decoder and run one
    evaluation (beam) decode pass.

    Fixes over the previous version:
      * no longer shadows the builtins ``iter`` and ``len``;
      * no longer re-binds ``final_states`` (encoder states) with the
        decoder's final states;
      * dead commented-out code removed.
    """
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        # Resolve the ids of the three special tokens through the table.
        stop_token = tf.constant(bytes(vocab.STOP_TOKEN, encoding='utf8'),
                                 dtype=tf.string)
        stop_token_id = vocab_lookup.lookup(stop_token)

        start_token = tf.constant(bytes(vocab.START_TOKEN, encoding='utf8'),
                                  dtype=tf.string)
        start_token_id = vocab_lookup.lookup(start_token)

        pad_token = tf.constant(bytes(vocab.PAD_TOKEN, encoding='utf8'),
                                dtype=tf.string)
        pad_token_id = vocab_lookup.lookup(pad_token)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE,
                                         NUM_EPOCH)
        data_iter = dataset.make_initializable_iterator()
        (src, tgt, inw, dlw), _ = data_iter.get_next()

        src_len = sequence.length_pre_embedding(src)
        tgt_len = sequence.length_pre_embedding(tgt)

        # Teacher-forcing pair: <start>+tgt as input, tgt+<stop> as output.
        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len,
                                                     stop_token_id,
                                                     pad_token_id)

        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        batch_size = tf.shape(src)[0]
        edit_vector = edit_encoder.random_noise_encoder(
            batch_size, EDIT_DIM, 14.0)

        embedding = tf.get_variable(
            'embeddings',
            shape=embed_matrix.shape,
            initializer=tf.constant_initializer(embed_matrix))

        src_embd = tf.nn.embedding_lookup(embedding, src)
        src_sent_embeds, enc_final_states = encoder.source_sent_encoder(
            src_embd, src_len, 20, 3, 0.8)

        agn = agenda.linear(enc_final_states, edit_vector, 4)

        dec_out = decoder.train_decoder(agn, embedding, dec_inputs,
                                        src_sent_embeds,
                                        tf.nn.embedding_lookup(embedding, inw),
                                        tf.nn.embedding_lookup(embedding, dlw),
                                        dec_inputs_len, src_len,
                                        sequence.length_pre_embedding(inw),
                                        sequence.length_pre_embedding(dlw), 5,
                                        20, 3, False)

        eval_dec_out = decoder.beam_eval_decoder(
            agn, embedding, start_token_id, stop_token_id, src_sent_embeds,
            tf.nn.embedding_lookup(embedding, inw),
            tf.nn.embedding_lookup(embedding, dlw), src_len,
            sequence.length_pre_embedding(inw),
            sequence.length_pre_embedding(dlw), 5, 20, 3, 40)

        # Unpack the train-decoder output and build the attention scores
        # to make sure those graph ops construct without error.
        dec_attn, dec_final_states, dec_len = dec_out
        stacked = decoder.attention_score(dec_out)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(data_iter.initializer)

            print(sess.run([eval_dec_out]))
コード例 #5
0
def test_decoder_prepares(dataset_file, embedding_file):
    """Check that prepare_decoder_inputs/output produce sequences that are
    one token longer than the target, start with <start> and end with <stop>.

    Fixes over the previous version:
      * the bare ``except: break`` swallowed the test's own AssertionError,
        so failing assertions silently passed — only
        ``tf.errors.OutOfRangeError`` (dataset exhaustion) is caught now,
        and the asserts live outside the ``try``;
      * the loop used to re-bind the graph tensors (``dec_inputs = sess.run(
        [dec_inputs, ...])``), so the second iteration fetched numpy arrays,
        crashed, and was swallowed — fetched values now use fresh names;
      * no longer shadows the builtin ``iter``.
    """
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        stop_token = tf.constant(bytes(vocab.STOP_TOKEN, encoding='utf8'),
                                 dtype=tf.string)
        stop_token_id = vocab_lookup.lookup(stop_token)

        start_token = tf.constant(bytes(vocab.START_TOKEN, encoding='utf8'),
                                  dtype=tf.string)
        start_token_id = vocab_lookup.lookup(start_token)

        pad_token = tf.constant(bytes(vocab.PAD_TOKEN, encoding='utf8'),
                                dtype=tf.string)
        pad_token_id = vocab_lookup.lookup(pad_token)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE,
                                         NUM_EPOCH)
        data_iter = dataset.make_initializable_iterator()
        (_, tgt, _, _), _ = data_iter.get_next()

        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len,
                                                     stop_token_id,
                                                     pad_token_id)

        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        # Last relevant token of each output row (should be <stop>).
        dec_outputs_last = sequence.last_relevant(
            tf.expand_dims(dec_outputs, 2), dec_outputs_len)
        dec_outputs_last = tf.squeeze(dec_outputs_last)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(data_iter.initializer)

            while True:
                try:
                    (di, do, tl, dil, dol, start_id, stop_id,
                     do_last) = sess.run([
                         dec_inputs, dec_outputs, tgt_len, dec_inputs_len,
                         dec_outputs_len, start_token_id, stop_token_id,
                         dec_outputs_last
                     ])
                except tf.errors.OutOfRangeError:
                    break

                # Both sequences are exactly one token longer than tgt.
                assert list(dil) == list(dol) == list(tl + 1)
                # Every input row begins with <start>.
                assert list(di[:, 0]) == list(
                    np.ones_like(di[:, 0]) * start_id)
                # Every output row ends with <stop>.
                assert list(do_last) == list(
                    np.ones_like(do_last) * stop_id)
コード例 #6
0
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    """Build a reduced neural-editor training graph (no edit encoder).

    Encodes the base sentence, maps its embedding directly to an agenda
    (no edit vector is computed here), and wires up a teacher-forced
    training decoder plus a greedy or beam inference decoder.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len)

    NOTE(review): many parameters (source/target/insert/delete words,
    edit_dim, micro_edit_ev_dim, num_heads, wa_*, meve_*, dropout-/norm-
    related regularization args, kill_edit, draw_edit, enable_vae) are
    accepted for interface parity with the full editor variant but are
    unused in this body.
    """
    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    # NOTE(review): output_len is taken from extended_output_words, not
    # output_words — presumably both have identical lengths; confirm.
    output_len = seq.length_pre_embedding(extended_output_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # Edit branches are not supported by this reduced graph.
    # NOTE(review): ``assert`` is stripped under ``python -O`` — consider
    # raising ValueError instead.
    assert kill_edit == False and draw_edit == False

    # [batch x agenda_dim] — agenda comes straight from the base sentence
    # embedding; no edit vector is mixed in here.
    base_agenda = linear(base_sent_embed, agenda_dim)

    # NOTE(review): arguments here are (words, extended_words, length) —
    # confirm this matches prepare_decoder_input_output's signature; the
    # 3-arg helper elsewhere takes (tgt_words, tgt_len, vocab_table).
    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    # Extended-vocab decoder inputs; -1 marks the start position here.
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        train_dec_inp_len,
        base_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    # Inference decoder: beam search when requested, greedy otherwise.
    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

        # NOTE(review): attention history graphing is only added on the
        # greedy branch — confirm whether the beam branch should get it too.
        add_decoder_attn_history_graph(infr_decoder)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len