Beispiel #1
0
def prepare_decoder_input_output(tgt_words, tgt_len, vocab_table):
    """Build the shifted decoder input/output sequences for training.

    Args:
        tgt_words: tensor of word ids, [batch x max_len]
        tgt_len: vector of sentence lengths, [batch]
        vocab_table: instance of tf.vocab_lookup_table

    Returns:
        dec_input: tensor of word ids, [batch x max_len+1]
        dec_input_len: vector of sentence lengths, [batch]
        dec_output: tensor of word ids, [batch x max_len+1]
        dec_output_len: vector of sentence lengths, [batch]
    """
    # Resolve the special-token ids once up front.
    start_id, stop_id, pad_id = (
        vocab.get_token_id(token, vocab_table)
        for token in (vocab.START_TOKEN, vocab.STOP_TOKEN, vocab.PAD_TOKEN)
    )

    # Decoder input: target sequence prefixed with the start token.
    dec_input = decoder.prepare_decoder_inputs(tgt_words, start_id)
    # Decoder output: target sequence suffixed with stop and padded.
    dec_output = decoder.prepare_decoder_output(tgt_words, tgt_len,
                                                stop_id, pad_id)

    return (dec_input, seq.length_pre_embedding(dec_input),
            dec_output, seq.length_pre_embedding(dec_output))
def get_trace_summary(vocab_i2s,
                      pred_tokens, tgt_tokens,
                      src_words, inserted_words, deleted_words,
                      pred_len, tgt_len):
    """Join source, edit words, target and prediction into one trace tensor.

    Returns a string tensor with the columns
    [src | inserted | deleted | target | prediction], one row per example.
    """
    # Beam outputs are rank-3 ([batch, beam, len]); join beams separately.
    if pred_tokens.shape.ndims > 2:
        joined_pred = metrics.join_beams(pred_tokens, pred_len)
    else:
        joined_pred = metrics.join_tokens(pred_tokens, pred_len)

    joined_tgt = metrics.join_tokens(tgt_tokens, tgt_len)
    joined_src = metrics.join_tokens(vocab_i2s.lookup(src_words),
                                     length_pre_embedding(src_words))
    # Insert/delete word sets are comma-separated for readability.
    joined_iw = metrics.join_tokens(vocab_i2s.lookup(inserted_words),
                                    length_pre_embedding(inserted_words), ', ')
    joined_dw = metrics.join_tokens(vocab_i2s.lookup(deleted_words),
                                    length_pre_embedding(deleted_words), ', ')

    columns = [joined_src, joined_iw, joined_dw, joined_tgt, joined_pred]
    return tf.concat(columns, axis=1)
Beispiel #3
0
    def encode_all(self, base_word_ids, source_word_ids, target_word_ids,
                   insert_word_ids, common_word_ids):
        """Encode the base sentence and the edit into the model's two contexts.

        Args:
            base_word_ids: [batch x max_len] word ids of the base sentence.
            source_word_ids: [batch x max_len] word ids of the source sentence.
            target_word_ids: [batch x max_len] word ids of the target sentence.
            insert_word_ids: [batch x max_len] word ids of inserted words.
            common_word_ids: [batch x max_len] word ids of common words.

        Returns:
            A pair of tuples:
                encoder_outputs: (base_encoded, base_attention_bias)
                edit_encoder_outputs: (edit_vector, mev_st, mev_ts)
        """
        batch_size = tf.shape(base_word_ids)[0]

        with tf.name_scope('encode_all'):
            # Pre-embedding sequence lengths for every input tensor.
            base_len = seq.length_pre_embedding(base_word_ids)
            src_len = seq.length_pre_embedding(source_word_ids)
            tgt_len = seq.length_pre_embedding(target_word_ids)
            iw_len = seq.length_pre_embedding(insert_word_ids)
            cw_len = seq.length_pre_embedding(common_word_ids)

            base_encoded, base_attention_bias = self.encoder(
                base_word_ids, base_len)

            kill_edit = self.config.editor.kill_edit
            draw_edit = self.config.editor.draw_edit

            # Micro-edit-vector attention needs a real edit vector, so the
            # kill/draw ablations must both be disabled.
            if self.config.editor.decoder.allow_mev_st_attn \
                    or self.config.editor.decoder.allow_mev_ts_attn:
                assert not kill_edit and not draw_edit

            if kill_edit:
                # Ablation: zero out the edit signal entirely.
                edit_vector = tf.zeros(
                    shape=(batch_size,
                           self.config.editor.edit_encoder.edit_dim))
                mev_st = mev_ts = None
            elif draw_edit:
                # Ablation: replace the edit vector with bounded random noise.
                edit_vector = random_noise_encoder(
                    batch_size, self.config.editor.edit_encoder.edit_dim,
                    self.config.editor.norm_max)
                mev_st = mev_ts = None
            else:
                # Normal path: derive the edit vector (and micro edit
                # vectors) from the source/target/insert/common words.
                edit_vector, mev_st, mev_ts = self.edit_encoder(
                    source_word_ids,
                    target_word_ids,
                    insert_word_ids,
                    common_word_ids,
                    src_len,
                    tgt_len,
                    iw_len,
                    cw_len,
                )

            encoder_outputs = (base_encoded, base_attention_bias)
            edit_encoder_outputs = (edit_vector, mev_st, mev_ts)

            return encoder_outputs, edit_encoder_outputs
    def _prepare_inputs(self, output_word_ids: tf.Tensor,
                        edit_vector: tf.Tensor):
        """Embed the shifted decoder inputs and add positional encodings.

        Returns (decoder_input, decoder_input_len) where decoder_input is
        the embedded, position-encoded (and optionally dropped-out) input.
        """
        # Prefix the output sequence with the start token.
        shifted_words = prepare_decoder_input(
            output_word_ids)  # [batch, output_len+1]
        shifted_max_len = tf.shape(shifted_words)[1]
        shifted_len = sequence.length_pre_embedding(
            shifted_words)  # [batch]

        # Look up word embeddings for the shifted sequence.
        embedded = self.embedding_layer(
            shifted_words)  # [batch, output_len+1, hidden_size)

        # Inject position information into the embeddings.
        with tf.name_scope('positional_encoding'):
            embedded += model_utils.get_position_encoding(
                shifted_max_len, self.config.orig_hidden_size)

        decoder_input = embedded
        if self.config.enable_dropout and self.config.layer_postprocess_dropout > 0.:
            decoder_input = tf.nn.dropout(
                decoder_input, 1 - self.config.layer_postprocess_dropout)

        return decoder_input, shifted_len
def test_wtf():
    """Smoke-test the input pipeline: decode string batches until exhaustion.

    Fixes: the original bare `except:` swallowed every error (including real
    pipeline bugs and KeyboardInterrupt) — only the expected
    tf.errors.OutOfRangeError is caught now.  Also avoids shadowing the
    builtin `iter` and drops length ops that were never consumed.
    """
    with tf.Graph().as_default():
        V, embed_matrix = vocab.read_word_embeddings(
            Path('../data') / 'word_vectors' / 'glove.6B.300d_yelp.txt',
            300,
            10000
        )

        table = vocab.create_vocab_lookup_tables(V)
        vocab_s2i = table[vocab.STR_TO_INT]
        vocab_i2s = table[vocab.INT_TO_STR]

        dataset = input_fn('../data/yelp_dataset_large_split/train.tsv', table, 64, 1)
        data_iter = dataset.make_initializable_iterator()

        (src, tgt, iw, dw), _ = data_iter.get_next()
        tgt_len = length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, vocab.get_token_id(vocab.START_TOKEN, vocab_s2i))

        dec_output = decoder.prepare_decoder_output(tgt, tgt_len, vocab.get_token_id(vocab.STOP_TOKEN, vocab_s2i),
                                                    vocab.get_token_id(vocab.PAD_TOKEN, vocab_s2i))

        # Map ids back to strings for manual inspection.
        t_src = vocab_i2s.lookup(src)
        t_tgt = vocab_i2s.lookup(tgt)
        t_iw = vocab_i2s.lookup(iw)
        t_dw = vocab_i2s.lookup(dw)

        t_do = vocab_i2s.lookup(dec_output)
        t_di = vocab_i2s.lookup(dec_inputs)

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer()])
            sess.run(data_iter.initializer)

            while True:
                try:
                    ts, tt, tiw, tdw, tdo, tdi = sess.run([t_src, t_tgt, t_iw, t_dw, t_do, t_di])
                except tf.errors.OutOfRangeError:
                    # Dataset exhausted — one full epoch consumed cleanly.
                    break
def test_rnn_encoder(dataset_file, embedding_file):
    """Check rnn_encoder emits [BATCH_SIZE, EDIT_DIM] outputs for every batch.

    Fixes: the original bare `except:` also caught AssertionError, so a
    failing shape assertion silently ended the loop and the test passed
    regardless.  Only tf.errors.OutOfRangeError ends the loop now, and the
    assertion runs outside the try block.
    """
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE,
                                         NUM_EPOCH)

        embedding = tf.get_variable(
            'embeddings',
            shape=embed_matrix.shape,
            initializer=tf.constant_initializer(embed_matrix))

        # `data_iter` instead of `iter` to avoid shadowing the builtin.
        data_iter = dataset.make_initializable_iterator()
        (src, tgt, iw, dw), _ = data_iter.get_next()

        EDIT_DIM = 8
        output = ev.rnn_encoder(tf.nn.embedding_lookup(embedding, src),
                                tf.nn.embedding_lookup(embedding, tgt),
                                tf.nn.embedding_lookup(embedding, iw),
                                tf.nn.embedding_lookup(embedding, dw),
                                sequence.length_pre_embedding(src),
                                sequence.length_pre_embedding(tgt),
                                sequence.length_pre_embedding(iw),
                                sequence.length_pre_embedding(dw), 256, 2, 256,
                                1, EDIT_DIM, 100.0, 0.1, 14.0, 0.8)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(data_iter.initializer)

            while True:
                try:
                    oeo = sess.run(output)
                except tf.errors.OutOfRangeError:
                    break
                # Assert outside the try so failures are reported, not eaten.
                assert oeo.shape == (BATCH_SIZE, EDIT_DIM)
def add_extra_summary_avg_bleu(vocab_i2s, decoder_output, ref_words, collections=None):
    """Attach a scalar 'bleu' summary comparing decoder output to references.

    Returns the average-BLEU tensor so callers can also fetch it directly.
    """
    ref_tokens = vocab_i2s.lookup(ref_words)
    ref_len = length_pre_embedding(ref_words)

    # NOTE: get_avg_bleu_smmary is the project helper's (misspelled) name.
    avg_bleu = get_avg_bleu_smmary(ref_tokens,
                                   decoder.str_tokens(decoder_output, vocab_i2s),
                                   ref_len,
                                   decoder.seq_length(decoder_output))
    tf.summary.scalar('bleu', avg_bleu, collections)
    return avg_bleu
def add_extra_summary_avg_bleu(hypo_tokens,
                               hypo_len,
                               ref_words,
                               vocab_i2s,
                               collections=None):
    """Attach a scalar 'bleu' summary for pre-tokenized hypotheses.

    Returns the average-BLEU tensor.
    """
    avg_bleu = get_avg_bleu(vocab_i2s.lookup(ref_words),
                            hypo_tokens,
                            length_pre_embedding(ref_words),
                            hypo_len)
    tf.summary.scalar('bleu', avg_bleu, collections)
    return avg_bleu
Beispiel #9
0
def calculate_loss(logits, output_words, input_words, tgt_words,
                   label_smoothing, vocab_size):
    """Combine the main output loss with scaled input/target losses.

    The input/target losses are *subtracted* with small weights —
    presumably to push the model away from trivially reproducing the
    input or target sequences.
    """
    def _smoothed_loss(words):
        # Build the gold (stop-appended/padded) sequence for `words` and
        # score the shared logits against it.
        gold_seq = decoder.prepare_decoder_output(
            words, sequence.length_pre_embedding(words))
        gold_seq_len = sequence.length_pre_embedding(gold_seq)
        loss, _ = optimizer.padded_cross_entropy_loss(
            logits, gold_seq, gold_seq_len, label_smoothing, vocab_size)
        return loss

    main_loss = _smoothed_loss(output_words)
    input_loss = _smoothed_loss(input_words)
    tgt_loss = _smoothed_loss(tgt_words)

    return main_loss - 1. / 50 * input_loss - 1. / 30 * tgt_loss
Beispiel #10
0
def add_extra_summary_avg_bleu(vocab_i2s,
                               decoder_output,
                               tgt_words,
                               collections=None):
    """Attach a scalar 'bleu' summary comparing predictions to targets.

    Returns the average-BLEU tensor.
    """
    predictions = decoder.str_tokens(decoder_output, vocab_i2s)
    prediction_lengths = decoder.seq_length(decoder_output)

    targets = vocab_i2s.lookup(tgt_words)
    target_lengths = length_pre_embedding(tgt_words)

    # NOTE: get_avg_bleu_smmary is the project helper's (misspelled) name.
    avg_bleu = get_avg_bleu_smmary(targets, predictions,
                                   target_lengths, prediction_lengths)
    tf.summary.scalar('bleu', avg_bleu, collections)
    return avg_bleu
Beispiel #11
0
def add_extra_summary_trace(vocab_i2s, decoder_output,
                            base_words, output_words,
                            src_words, tgt_words, inserted_words, deleted_words,
                            collections=None):
    """Attach a text 'trace' summary (inputs, edits, targets, predictions).

    Returns the trace-summary tensor.
    """
    predictions = decoder.str_tokens(decoder_output, vocab_i2s)
    prediction_lengths = decoder.seq_length(decoder_output)

    targets = vocab_i2s.lookup(tgt_words)
    target_lengths = length_pre_embedding(tgt_words)

    trace_summary = get_trace_summary(vocab_i2s, predictions, targets,
                                      src_words, inserted_words, deleted_words,
                                      prediction_lengths, target_lengths)
    tf.summary.text('trace', trace_summary, collections)
    return trace_summary
Beispiel #12
0
    def get_logits(self, encoded_inputs, output_word_ids):
        """Run the decoder in train mode and return the output logits."""
        with tf.name_scope('logits'):
            # Unpack the (encoder, edit-encoder) result pair in one step.
            (hidden_states, attention_bias), (edit_vector, mev_st, mev_ts) = \
                encoded_inputs

            return self.decoder(output_word_ids,
                                seq.length_pre_embedding(output_word_ids),
                                hidden_states,
                                attention_bias,
                                edit_vector,
                                mev_st,
                                mev_ts,
                                mode='train')
Beispiel #13
0
def add_extra_summary_trace(pred_tokens,
                            pred_len,
                            base_words,
                            output_words,
                            src_words,
                            tgt_words,
                            inserted_words,
                            deleted_words,
                            collections=None):
    """Attach a text 'trace' summary from pre-computed prediction tokens.

    Returns the trace-summary tensor.
    """
    # Resolve the shared int->string lookup table from the vocab module.
    i2s = vocab.get_vocab_lookup_tables()[vocab.INT_TO_STR]

    targets = i2s.lookup(tgt_words)
    target_lengths = length_pre_embedding(tgt_words)

    trace_summary = get_trace(pred_tokens, targets, src_words,
                              inserted_words, deleted_words,
                              pred_len, target_lengths)
    tf.summary.text('trace', trace_summary, collections)
    return trace_summary
def test_length_pre_embedding():
    """length_pre_embedding must recover the gold lengths of padded rows."""
    def padded_row(true_len):
        # `true_len` random non-pad tokens followed by zero padding.
        return [random.randint(4, 1000) if i < true_len else 0
                for i in range(MAX_LEN)]

    # Force at least one full-length row so the max-length case is covered.
    gold_seq_lengths = [random.randint(4, MAX_LEN)
                        for _ in range(BATCH_SIZE - 1)] + [MAX_LEN]
    batch = np.array([padded_row(n) for n in gold_seq_lengths],
                     dtype=np.float32)

    tf.enable_eager_execution()
    lengths = length_pre_embedding(batch)

    assert lengths.shape == (BATCH_SIZE,)
    assert list(lengths.numpy()) == gold_seq_lengths
def test_context_encoder(dataset_file, embedding_file):
    """Check context_encoder output shape for every batch in the dataset.

    Fixes: the original bare `except:` also caught AssertionError, so a
    failing shape assertion silently ended the loop and the test passed.
    Only tf.errors.OutOfRangeError ends the loop now.
    """
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE,
                                         NUM_EPOCH)

        embedding = tf.get_variable(
            'embeddings',
            shape=embed_matrix.shape,
            initializer=tf.constant_initializer(embed_matrix))

        # `data_iter` instead of `iter` to avoid shadowing the builtin.
        data_iter = dataset.make_initializable_iterator()
        (_, _, src, _), _ = data_iter.get_next()

        src_len = sequence.length_pre_embedding(src)
        src_embd = tf.nn.embedding_lookup(embedding, src)

        output = ev.context_encoder(src_embd, src_len, HIDDEN_DIM, NUM_LAYER)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(data_iter.initializer)

            while True:
                try:
                    oeo, o_src, o_src_len, o_src_embd = sess.run(
                        [output, src, src_len, src_embd])
                except tf.errors.OutOfRangeError:
                    break
                # Assert outside the try so failures are reported, not eaten.
                assert oeo.shape == (BATCH_SIZE, o_src_len.max(),
                                     HIDDEN_DIM)
Beispiel #16
0
def test_encoder(dataset_file, embedding_file):
    """Check the bidirectional encoder output shape for one batch.

    Fixes: renamed the iterator variable so it no longer shadows the
    builtin `iter`.
    """
    d_fn, gold_dataset = dataset_file
    e_fn, gold_embeds = embedding_file

    v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
    vocab_lookup = vocab.get_vocab_lookup(v)

    dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE, NUM_EPOCH)

    embedding = tf.get_variable(
        'embeddings',
        shape=embed_matrix.shape,
        initializer=tf.constant_initializer(embed_matrix))

    data_iter = dataset.make_initializable_iterator()
    (src, _, _, _), _ = data_iter.get_next()

    src_len = sequence.length_pre_embedding(src)
    src_embd = tf.nn.embedding_lookup(embedding, src)

    encoder_output, _ = encoder.bidirectional_encoder(src_embd, src_len,
                                                      HIDDEN_DIM, NUM_LAYER,
                                                      0.9)

    with tf.Session() as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
            tf.tables_initializer()
        ])
        sess.run(data_iter.initializer)

        oeo, o_src, o_src_len, o_src_embd = sess.run(
            [encoder_output, src, src_len, src_embd])

        # Dump trainable variables for manual inspection of the graph.
        for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            print(i)

        assert oeo.shape == (BATCH_SIZE, o_src_len.max(), HIDDEN_DIM)
Beispiel #17
0
def editor_train(base_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 embed_matrix,
                 vocab_table,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    """Build the neural-editor training graph (RNN variant).

    Encodes the base sentence, derives an edit vector from the insert/delete
    words (or zeros / random noise under the kill_edit / draw_edit
    ablations), combines both into an agenda vector, and builds a train
    decoder plus an inference decoder (beam or greedy).

    Args:
        base_words: [batch x max_len] word ids of the base sentence.
        source_words, target_words: [batch x max_len] word ids.
        insert_words, delete_words: [batch x max_len] edit word ids.
        embed_matrix: embedding matrix; not referenced directly in this
            body (embeddings come from vocab.get_embeddings()).
        vocab_table: lookup table used to resolve special token ids.
        kill_edit: if True, use an all-zero edit vector (ablation).
        draw_edit: if True, use a random edit vector (ablation).
        use_beam_decoder: choose beam vs. greedy inference decoder.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len)
    """
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    # NOTE(review): src_len is computed but never consumed in this variant.
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # [batch x edit_dim]
    if kill_edit:
        # Ablation: no edit signal at all.
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            # Ablation: norm-bounded random edit vector.
            edit_vector = edit_encoder.random_noise_encoder(
                batch_size, edit_dim, norm_max)
        else:
            # Normal path: accumulate insert/delete word embeddings.
            edit_vector = edit_encoder.accumulator_encoder(
                insert_word_embeds,
                delete_word_embeds,
                iw_len,
                dw_len,
                edit_dim,
                lamb_reg,
                norm_eps,
                norm_max,
                dropout_keep,
                enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    # Shifted decoder input/output sequences for teacher forcing.
    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(target_words, tgt_len, vocab_table)

    train_decoder = decoder.train_decoder(
        input_agenda,
        embeddings,
        train_dec_inp,
        base_sent_hidden_states,
        insert_word_embeds,
        delete_word_embeds,
        train_dec_inp_len,
        base_len,
        iw_len,
        dw_len,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states,
            insert_word_embeds,
            delete_word_embeds,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states,
            insert_word_embeds,
            delete_word_embeds,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    """Build a copy-mechanism editor training graph (no edit vector).

    The agenda is derived from the base sentence alone via
    linear(base_sent_embed, agenda_dim); kill_edit and draw_edit must
    both be False.  Extended (copy-vocabulary) word ids and `oov` are
    threaded through to the decoders.  Many signature parameters
    (e.g. source/target/insert/delete words, edit_dim) are accepted for
    interface compatibility but not referenced in this body.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len)
    """
    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # This variant supports neither edit ablation.
    assert kill_edit == False and draw_edit == False

    # [batch x agenda_dim]
    base_agenda = linear(base_sent_embed, agenda_dim)

    # Shifted decoder input/output sequences for teacher forcing.
    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    # Extended-vocabulary decoder inputs; -1 is used as the start id here.
    # NOTE(review): confirm -1 is the intended start token for extended ids.
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        train_dec_inp_len,
        base_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

        # Attention-history tracing is only wired up for the greedy decoder.
        add_decoder_attn_history_graph(infr_decoder)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 recons_dense_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    """Build the copy-mechanism editor graph with an edit-vector
    reconstruction loss.

    Encodes the base sentence, builds an edit vector from insert/delete
    embeddings (kill_edit and draw_edit must both be False per the assert
    below), combines both into the agenda, builds train + inference
    decoders, and finally adds a loss that reconstructs the edit vector
    from the output words.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len)
    """
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    output_word_embeds = vocab.embed_tokens(output_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # Neither ablation is supported in this variant.
    assert kill_edit == False and draw_edit == False

    # NOTE(review): wa_inserted/wa_deleted are only bound in the
    # accumulator_encoder branch below.  The assert above guarantees that
    # branch is taken (kill_edit and draw_edit are both False); if the
    # assert were removed, the later decoder calls would raise NameError.
    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector = accumulator_encoder(insert_word_embeds,
                                              delete_word_embeds,
                                              iw_len,
                                              dw_len,
                                              edit_dim,
                                              lamb_reg,
                                              norm_eps,
                                              norm_max,
                                              dropout_keep,
                                              enable_vae=enable_vae)

            # Placeholder word-alignment attention states (constant [[0]]
            # pairs) — the accumulator encoder does not produce real ones.
            wa_inserted, wa_deleted = (tf.constant([[0]]), tf.constant(
                [[0]])), (tf.constant([[0]]), tf.constant([[0]]))

    # [batch x agenda_dim]
    base_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    # Shifted decoder input/output sequences for teacher forcing.
    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    # Extended-vocabulary decoder inputs; -1 is used as the start id here.
    # NOTE(review): confirm -1 is the intended start token for extended ids.
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        wa_inserted,
        wa_deleted,
        train_dec_inp_len,
        base_len,
        src_len,
        tgt_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    # Auxiliary loss: reconstruct the edit vector from the output words.
    edit_vector_recons = output_words_to_edit_vector(
        output_word_embeds, output_len, edit_dim, ctx_hidden_dim,
        ctx_hidden_layer, recons_dense_layers, swap_memory)
    optimizer.add_reconst_loss(edit_vector, edit_vector_recons)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
Beispiel #20
0
def model_fn(features, mode, config, embedding_matrix, vocab_tables):
    """Build a tf.estimator.EstimatorSpec for the neural editor model.

    Args:
        features: 6-tuple of id tensors ``(base_words, output_words,
            src_words, tgt_words, inserted_words, commong_words)``. In
            PREDICT mode the second element carries no gold output and
            ``tgt_words`` is substituted as the output side.
        mode: one of ``tf.estimator.ModeKeys`` (TRAIN / EVAL / PREDICT).
        config: nested configuration object exposing ``put``/``get``.
        embedding_matrix: pretrained embedding matrix used to initialize
            the shared embeddings; first dimension is the vocab size.
        vocab_tables: vocabulary lookup tables (not used directly here).

    Returns:
        A ``tf.estimator.EstimatorSpec`` appropriate for ``mode``.
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        # No gold output side at prediction time; use the target words.
        base_words, _, src_words, tgt_words, inserted_words, commong_words = features
        output_words = tgt_words
    else:
        # NOTE(review): "commong_words" looks like a typo for "common_words";
        # kept as-is because it is used consistently throughout this function.
        base_words, output_words, src_words, tgt_words, inserted_words, commong_words = features

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    tf.add_to_collection('is_training', is_training)

    # Hard-disable every dropout knob when not training (both the editor's
    # own dropout and the transformer sub-module's dropout rates).
    if mode != tf.estimator.ModeKeys.TRAIN:
        config.put('editor.enable_dropout', False)
        config.put('editor.dropout_keep', 1.0)
        config.put('editor.dropout', 0.0)

        config.put('editor.transformer.enable_dropout', False)
        config.put('editor.transformer.layer_postprocess_dropout', 0.0)
        config.put('editor.transformer.attention_dropout', 0.0)
        config.put('editor.transformer.relu_dropout', 0.0)

    # Install the pretrained embeddings before any graph that reads them.
    vocab.init_embeddings(embedding_matrix)
    EmbeddingSharedWeights.init_from_embedding_matrix()

    editor_model = Editor(config)
    logits, beam_prediction = editor_model(base_words, src_words, tgt_words,
                                           inserted_words, commong_words,
                                           output_words)

    # NOTE(review): other call sites pass stop/pad token ids to
    # prepare_decoder_output; here only two args are given — presumably a
    # different signature/defaults. Verify against decoder's definition.
    targets = decoder.prepare_decoder_output(
        output_words, sequence.length_pre_embedding(output_words))
    target_lengths = sequence.length_pre_embedding(targets)

    vocab_size = embedding_matrix.shape[0]
    loss, weights = optimizer.padded_cross_entropy_loss(
        logits, targets, target_lengths, config.optim.label_smoothing,
        vocab_size)

    # NOTE(review): the train op (and loss) are built unconditionally, even
    # for EVAL/PREDICT where train_op is unused — confirm this is intended.
    train_op = optimizer.get_train_op(loss, config)

    tf.logging.info("Trainable variable")
    for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        tf.logging.info(str(i))

    tf.logging.info("Num of Trainable parameters")
    tf.logging.info(
        np.sum([
            np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()
        ]))

    if mode == tf.estimator.ModeKeys.TRAIN:
        decoded_ids = decoder.logits_to_decoded_ids(logits)
        # Extra summaries (BLEU, traces, ...) are collected under 'extra'
        # and emitted via dedicated hooks rather than the default writer.
        ops = add_extra_summary(config,
                                decoded_ids,
                                target_lengths,
                                base_words,
                                output_words,
                                src_words,
                                tgt_words,
                                inserted_words,
                                commong_words,
                                collections=['extra'])

        hooks = [
            get_train_extra_summary_writer(config),
            get_extra_summary_logger(ops, config),
        ]

        if config.get('logger.enable_profiler', False):
            hooks.append(get_profiler_hook(config))

        return tf.estimator.EstimatorSpec(mode,
                                          train_op=train_op,
                                          loss=loss,
                                          training_hooks=hooks)

    elif mode == tf.estimator.ModeKeys.EVAL:
        decoded_ids = decoder.logits_to_decoded_ids(logits)
        ops = add_extra_summary(config,
                                decoded_ids,
                                target_lengths,
                                base_words,
                                output_words,
                                src_words,
                                tgt_words,
                                inserted_words,
                                commong_words,
                                collections=['extra'])

        # BLEU is averaged across eval batches via a streaming metric.
        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            evaluation_hooks=[get_extra_summary_logger(ops, config)],
            eval_metric_ops={'bleu': tf_metrics.streaming_mean(ops[ES_BLEU])})

    elif mode == tf.estimator.ModeKeys.PREDICT:
        # beam_prediction carries beam-major outputs; transpose so the beam
        # axis comes last: [batch x beam x time] -> [batch x time x beam].
        decoded_ids, decoded_lengths, scores = beam_prediction
        tokens = decoder.str_tokens(decoded_ids)

        preds = {
            'str_tokens': tf.transpose(tokens, [0, 2, 1]),
            'decoded_ids': tf.transpose(decoded_ids, [0, 2, 1]),
            'lengths': decoded_lengths,
            'joined': metrics.join_tokens(tokens, decoded_lengths)
        }

        # Optionally expose micro-edit-extractor attention maps when the
        # model registered them in this collection (two (enc, dec, dec-enc)
        # triples: source->target and target->source).
        tmee_attentions = tf.get_collection(
            'TransformerMicroEditExtractor_Attentions')
        if len(tmee_attentions) > 0:
            preds.update({
                'tmee_attentions_st_enc_self': tmee_attentions[0][0],
                'tmee_attentions_st_dec_self': tmee_attentions[0][1],
                'tmee_attentions_st_dec_enc': tmee_attentions[0][2],
                'tmee_attentions_ts_enc_self': tmee_attentions[1][0],
                'tmee_attentions_ts_dec_self': tmee_attentions[1][1],
                'tmee_attentions_ts_dec_enc': tmee_attentions[1][2],
                'src_words': src_words,
                'tgt_words': tgt_words,
                'base_words': base_words,
                'output_words': output_words
            })

        return tf.estimator.EstimatorSpec(mode, predictions=preds)
Beispiel #21
0
def editor_train(base_words,
                 output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    """Assemble the edit-encoder / agenda / decoder training graph.

    Encodes the base sentence, derives an edit vector from the
    (source, target, insert, delete) words via the attention edit
    encoder, combines both into an agenda, and builds a teacher-forced
    training decoder plus a greedy or beam inference decoder.

    Args:
        base_words, output_words, source_words, target_words,
        insert_words, delete_words: [batch x max_len] word-id tensors.
        hidden_dim..swap_memory: architecture / regularization
            hyper-parameters forwarded to the sub-modules.
        kill_edit: if True, use a zero edit vector (currently rejected).
        draw_edit: if True, sample a random edit vector (currently rejected).
        use_beam_decoder: build a beam-search inference decoder instead of
            a greedy one.
        use_dropout, no_insert_delete_attn, enable_vae: feature flags
            forwarded to the sub-modules.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len).

    Raises:
        AssertionError: if kill_edit or draw_edit is set (unsupported;
            those paths do not produce the word-alignment attentions the
            decoders below require).
    """
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    src_word_embeds = vocab.embed_tokens(source_words)
    tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    sent_encoder = tf.make_template('sent_encoder',
                                    encoder.source_sent_encoder,
                                    hidden_dim=hidden_dim,
                                    num_layer=num_encoder_layers,
                                    swap_memory=swap_memory,
                                    use_dropout=use_dropout,
                                    dropout_keep=dropout_keep)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = sent_encoder(
        base_word_embeds, base_len)

    # The kill_edit / draw_edit branches do not define wa_inserted /
    # wa_deleted, which the decoders below require, so those modes are
    # rejected up front. (Was `assert kill_edit == False and ...`; note
    # the check disappears under `python -O` — the branches below would
    # then raise NameError at the decoder call instead.)
    assert not kill_edit and not draw_edit, \
        'kill_edit/draw_edit are unsupported: they do not produce the ' \
        'word-alignment attentions required by the decoders'

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    elif draw_edit:
        edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
    else:
        edit_vector, wa_inserted, wa_deleted = attn_encoder(
            src_word_embeds,
            tgt_word_embeds,
            insert_word_embeds,
            delete_word_embeds,
            src_len,
            tgt_len,
            iw_len,
            dw_len,
            ctx_hidden_dim,
            ctx_hidden_layer,
            wa_hidden_dim,
            wa_hidden_layer,
            meve_hidden_dim,
            meve_hidden_layers,
            edit_dim,
            micro_edit_ev_dim,
            num_heads,
            lamb_reg,
            norm_eps,
            norm_max,
            sent_encoder,
            use_dropout=use_dropout,
            dropout_keep=dropout_keep,
            swap_memory=swap_memory,
            enable_vae=enable_vae)

    # [batch x agenda_dim] — fuse base-sentence embedding with edit vector.
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    # NOTE(review): vocab_table=None relies on vocab.get_token_id handling
    # a missing table (see prepare_decoder_input_output) — confirm.
    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, output_len, None)

    # Teacher-forced decoder used for the training loss.
    train_decoder = decoder.train_decoder(
        input_agenda,
        embeddings,
        train_dec_inp,
        base_sent_hidden_states,
        wa_inserted,
        wa_deleted,
        train_dec_inp_len,
        base_len,
        iw_len,
        dw_len,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    # Inference decoder: beam search or greedy, sharing the same inputs.
    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def test_decoder_prepares(dataset_file, embedding_file):
    """Check decoder input/output preparation over a real dataset.

    For every batch: input/output lengths are tgt_len + 1, every decoder
    input starts with START, and every decoder output ends with STOP.

    Fixes over the previous version:
      * a bare ``except:`` swallowed AssertionError, so the test could
        never fail — now only the iterator's OutOfRangeError ends the loop;
      * the ``sess.run`` results rebound the graph-tensor names
        (``dec_inputs`` etc.), so the second iteration fetched numpy
        arrays and crashed (silently, via the bare except) — fetched
        values now use distinct ``*_v`` names;
      * local ``iter`` no longer shadows the builtin.
    """
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        # Resolve the special-token ids through the vocab lookup table.
        stop_token = tf.constant(bytes(vocab.STOP_TOKEN, encoding='utf8'),
                                 dtype=tf.string)
        stop_token_id = vocab_lookup.lookup(stop_token)

        start_token = tf.constant(bytes(vocab.START_TOKEN, encoding='utf8'),
                                  dtype=tf.string)
        start_token_id = vocab_lookup.lookup(start_token)

        pad_token = tf.constant(bytes(vocab.PAD_TOKEN, encoding='utf8'),
                                dtype=tf.string)
        pad_token_id = vocab_lookup.lookup(pad_token)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE,
                                         NUM_EPOCH)
        data_iter = dataset.make_initializable_iterator()
        (_, tgt, _, _), _ = data_iter.get_next()

        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len,
                                                     stop_token_id,
                                                     pad_token_id)

        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        # Last relevant token of each output row (should be STOP).
        dec_outputs_last = sequence.last_relevant(
            tf.expand_dims(dec_outputs, 2), dec_outputs_len)
        dec_outputs_last = tf.squeeze(dec_outputs_last)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(data_iter.initializer)

            while True:
                try:
                    (dec_inputs_v, dec_outputs_v, tgt_len_v, dil, dol,
                     start_id_v, stop_id_v, dec_outputs_last_v,
                     tgt_v) = sess.run([
                         dec_inputs, dec_outputs, tgt_len, dec_inputs_len,
                         dec_outputs_len, start_token_id, stop_token_id,
                         dec_outputs_last, tgt
                     ])
                except tf.errors.OutOfRangeError:
                    # Dataset exhausted.
                    break

                # START/STOP each add one token to the length.
                assert list(dil) == list(dol) == list(tgt_len_v + 1)
                # Every decoder input begins with START.
                assert list(dec_inputs_v[:, 0]) == list(
                    np.ones_like(dec_inputs_v[:, 0]) * start_id_v)
                # Every decoder output ends with STOP.
                assert list(dec_outputs_last_v) == list(
                    np.ones_like(dec_outputs_last_v) * stop_id_v)
def test_decoder_train(dataset_file, embedding_file):
    """Smoke-test the train decoder and beam eval decoder end to end.

    Builds the full pipeline — dataset iterator, special-token lookup,
    decoder input/output prep, random edit vector, sentence encoder,
    agenda, train decoder, beam eval decoder — and runs one fetch of the
    beam decoder output. No assertions; failure mode is an exception.
    """
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        # Resolve special-token ids through the vocab lookup table.
        stop_token = tf.constant(bytes(vocab.STOP_TOKEN, encoding='utf8'),
                                 dtype=tf.string)
        stop_token_id = vocab_lookup.lookup(stop_token)

        start_token = tf.constant(bytes(vocab.START_TOKEN, encoding='utf8'),
                                  dtype=tf.string)
        start_token_id = vocab_lookup.lookup(start_token)

        pad_token = tf.constant(bytes(vocab.PAD_TOKEN, encoding='utf8'),
                                dtype=tf.string)
        pad_token_id = vocab_lookup.lookup(pad_token)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE,
                                         NUM_EPOCH)
        # NOTE(review): `iter` shadows the builtin; rename if this is touched.
        iter = dataset.make_initializable_iterator()
        (src, tgt, inw, dlw), _ = iter.get_next()

        src_len = sequence.length_pre_embedding(src)

        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len,
                                                     stop_token_id,
                                                     pad_token_id)

        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        # Random edit vector (max norm 14.0) standing in for a real encoder.
        batch_size = tf.shape(src)[0]
        edit_vector = edit_encoder.random_noise_encoder(
            batch_size, EDIT_DIM, 14.0)

        embedding = tf.get_variable(
            'embeddings',
            shape=embed_matrix.shape,
            initializer=tf.constant_initializer(embed_matrix))

        src_embd = tf.nn.embedding_lookup(embedding, src)
        # Positional args: hidden_dim=20, num_layer=3, dropout_keep=0.8
        # — presumably matching source_sent_encoder's signature; confirm.
        src_sent_embeds, final_states = encoder.source_sent_encoder(
            src_embd, src_len, 20, 3, 0.8)

        agn = agenda.linear(final_states, edit_vector, 4)

        dec_out = decoder.train_decoder(agn, embedding, dec_inputs,
                                        src_sent_embeds,
                                        tf.nn.embedding_lookup(embedding, inw),
                                        tf.nn.embedding_lookup(embedding, dlw),
                                        dec_inputs_len, src_len,
                                        sequence.length_pre_embedding(inw),
                                        sequence.length_pre_embedding(dlw), 5,
                                        20, 3, False)

        # eval_dec_out = decoder.greedy_eval_decoder(
        #     agn, embedding,
        #     start_token_id, stop_token_id,
        #     src_sent_embeds,
        #     tf.nn.embedding_lookup(embedding, inw),
        #     tf.nn.embedding_lookup(embedding, dlw),
        #     src_len, sequence.length_pre_embedding(inw), sequence.length_pre_embedding(dlw),
        #     5, 20, 3, 40
        # )

        eval_dec_out = decoder.beam_eval_decoder(
            agn, embedding, start_token_id, stop_token_id, src_sent_embeds,
            tf.nn.embedding_lookup(embedding, inw),
            tf.nn.embedding_lookup(embedding, dlw), src_len,
            sequence.length_pre_embedding(inw),
            sequence.length_pre_embedding(dlw), 5, 20, 3, 40)

        # saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
        # s = tf.summary.FileWriter('data/an')
        # s.add_graph(g)
        #
        # all_print = tf.get_collection('print')

        # NOTE(review): `len` shadows the builtin here; these unpacked
        # values are otherwise unused in this test.
        an, final_states, len = dec_out
        stacked = decoder.attention_score(dec_out)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(iter.initializer)

            # Single fetch: exercises the beam decoder graph once.
            print(sess.run([eval_dec_out]))