Example #1
def residual_decoder(agenda,
                     dec_inputs,
                     dec_input_lengths,
                     hidden_dim,
                     num_layer,
                     swap_memory,
                     enable_dropout=False,
                     dropout_keep=1.,
                     name=None):
    with tf.variable_scope(name, 'residual_decoder', []):
        batch_size = tf.shape(dec_inputs)[0]
        embeddings = vocab.get_embeddings()

        # Concatenate agenda [y_hat;base_input_embed] with decoder inputs

        # [batch x max_len x word_dim]
        dec_inputs = tf.nn.embedding_lookup(embeddings, dec_inputs)
        max_len = tf.shape(dec_inputs)[1]

        # [batch x 1 x agenda_dim]
        agenda = tf.expand_dims(agenda, axis=1)

        # [batch x max_len x agenda_dim]
        agenda = tf.tile(agenda, [1, max_len, 1])

        # [batch x max_len x word_dim+agenda_dim]
        dec_inputs = tf.concat([dec_inputs, agenda], axis=2)

        helper = seq2seq.TrainingHelper(dec_inputs,
                                        dec_input_lengths,
                                        name='train_helper')
        cell = tf_rnn.MultiRNNCell([
            create_rnn_layer(i, hidden_dim // 2, enable_dropout, dropout_keep)
            for i in range(num_layer)
        ])
        zero_states = create_trainable_initial_states(batch_size, cell)

        output_layer = DecoderOutputLayer(embeddings)
        decoder = seq2seq.BasicDecoder(cell, helper, zero_states, output_layer)

        outputs, state, length = seq2seq.dynamic_decode(
            decoder, swap_memory=swap_memory)

        return outputs, state, length
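Example #1 calls `create_trainable_initial_states`, which is not shown here. Below is a minimal sketch of one plausible implementation, assuming the helper simply replaces each zero-state tensor with a trainable per-example vector tiled to the runtime batch size; the real helper may differ.

import tensorflow as tf
from tensorflow.python.util import nest


def create_trainable_initial_states_sketch(batch_size, cell, name='trainable_init_states'):
    # Hypothetical stand-in for create_trainable_initial_states.
    with tf.variable_scope(name):
        template = cell.zero_state(1, tf.float32)
        flat = []
        for i, t in enumerate(nest.flatten(template)):
            var = tf.get_variable('init_state_%d' % i,
                                  shape=t.shape[1:],
                                  dtype=t.dtype,
                                  initializer=tf.zeros_initializer())
            # Tile the per-example state to [batch_size x ...].
            multiples = [batch_size] + [1] * (t.shape.ndims - 1)
            flat.append(tf.tile(tf.expand_dims(var, 0), multiples))
        return nest.pack_sequence_as(template, flat)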
Example #2
def prepare_output_embed(
    decoder_output,
    temperature_starter,
    decay_rate,
    decay_steps,
):
    # [VOCAB x word_dim]
    embeddings = vocab.get_embeddings()

    # Extend embedding matrix to support oov tokens
    unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
    unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
    unk_embeddings = tf.tile(unk_embed, [50, 1])

    # [VOCAB+50 x word_dim]
    embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

    global_step = tf.train.get_global_step()
    temperature = tf.train.exponential_decay(temperature_starter,
                                             global_step,
                                             decay_steps,
                                             decay_rate,
                                             name='temperature')
    tf.summary.scalar('temper', temperature, ['extra'])

    # [batch x max_len x VOCAB+50], softmax probabilities
    outputs = decoder.rnn_output(decoder_output)

    # replace non-positive probabilities with a tiny constant for numerical stability
    outputs = tf.where(tf.less_equal(outputs, 0),
                       tf.ones_like(outputs) * 1e-10, outputs)

    # sample (approximately) one-hot vectors via the Gumbel-Softmax relaxation
    dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

    # [batch x max_len x VOCAB+50], one_hot
    outputs_one_hot = dist.sample()

    # [batch x max_len x word_dim], one_hot^T * embedding_matrix
    outputs_embed = tf.einsum("btv,vd->btd", outputs_one_hot,
                              embeddings_extended)

    return outputs_embed
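The `RelaxedOneHotCategorical` sample above is the Gumbel-Softmax trick: it yields approximately one-hot vectors that remain differentiable with respect to the probabilities, and the einsum is just a soft embedding lookup. A toy, self-contained illustration, assuming `tfd` refers to `tensorflow_probability.distributions` (the same class also exists under `tf.contrib.distributions`):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy shapes: batch=2, max_len=3, VOCAB=5, word_dim=4.
probs = tf.nn.softmax(tf.random_normal([2, 3, 5]))
embedding_matrix = tf.random_normal([5, 4])

# Approximately one-hot, but gradients flow back into probs.
soft_one_hot = tfd.RelaxedOneHotCategorical(temperature=0.5, probs=probs).sample()

# "btv,vd->btd" is a soft embedding lookup; equivalent to a batched matmul.
via_einsum = tf.einsum("btv,vd->btd", soft_one_hot, embedding_matrix)
via_tensordot = tf.tensordot(soft_one_hot, embedding_matrix, axes=[[2], [0]])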
Example #3
def create_decoder_cell(agenda,
                        extended_base_words,
                        oov,
                        base_sent_hiddens,
                        mev_st,
                        mev_ts,
                        base_length,
                        iw_length,
                        dw_length,
                        vocab_size,
                        attn_dim,
                        hidden_dim,
                        num_layer,
                        enable_alignment_history=False,
                        enable_dropout=False,
                        dropout_keep=1.,
                        no_insert_delete_attn=False,
                        beam_width=None):
    base_attn = seq2seq.BahdanauAttention(attn_dim,
                                          base_sent_hiddens,
                                          base_length,
                                          name='base_attn')

    cnx_src, micro_evs_st = mev_st
    mev_st_attn = seq2seq.BahdanauAttention(attn_dim,
                                            cnx_src,
                                            iw_length,
                                            name='mev_st_attn')
    # Override the attention values: scores are still computed against cnx_src,
    # but the attention context is built from the micro edit vectors.
    mev_st_attn._values = micro_evs_st

    attns = [base_attn, mev_st_attn]

    if not no_insert_delete_attn:
        cnx_tgt, micro_evs_ts = mev_ts
        mev_ts_attn = seq2seq.BahdanauAttention(attn_dim,
                                                cnx_tgt,
                                                dw_length,
                                                name='mev_ts_attn')
        # Same override: scores against cnx_tgt, context from the micro edit vectors.
        mev_ts_attn._values = micro_evs_ts

        attns += [mev_ts_attn]

    # Keep alignment history only at inference time; this overrides the
    # enable_alignment_history argument using the flag stored in the graph collection.
    is_training = tf.get_collection('is_training')[0]
    enable_alignment_history = not is_training

    bottom_cell = tf_rnn.LSTMCell(hidden_dim, name='bottom_cell')
    bottom_attn_cell = seq2seq.AttentionWrapper(
        bottom_cell,
        tuple(attns),
        alignment_history=enable_alignment_history,
        output_attention=False,
        name='att_bottom_cell')

    all_cells = [bottom_attn_cell]

    for i in range(num_layer - 1):
        cell = tf_rnn.LSTMCell(hidden_dim, name='layer_%s' % (i + 1))
        if enable_dropout and dropout_keep < 1.:
            cell = tf_rnn.DropoutWrapper(cell, output_keep_prob=dropout_keep)

        all_cells.append(cell)

    decoder_cell = AttentionAugmentRNNCell(all_cells)
    decoder_cell.set_agenda(agenda)
    decoder_cell.set_source_attn_index(0)

    output_layer = DecoderOutputLayer(vocab.get_embeddings())

    pg_cell = PointerGeneratorWrapper(decoder_cell,
                                      extended_base_words,
                                      50,  # extra ids reserved for OOV words (cf. the VOCAB+50 embedding extension)
                                      output_layer,
                                      vocab_size,
                                      decoder_cell.get_source_attention,
                                      name='PointerGeneratorWrapper')

    if beam_width:
        true_batch_size = tf.cast(
            tf.shape(base_sent_hiddens)[0] / beam_width, tf.int32)
    else:
        true_batch_size = tf.shape(base_sent_hiddens)[0]

    zero_state = create_trainable_zero_state(decoder_cell,
                                             true_batch_size,
                                             beam_width=beam_width)

    return pg_cell, zero_state
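When `beam_width` is set, `create_decoder_cell` divides the memory batch dimension by `beam_width`, i.e. it assumes the attention memories were already tiled for beam search. A sketch of the assumed call-site preparation using the standard `tf.contrib.seq2seq` helper:

import tensorflow as tf
from tensorflow.contrib import seq2seq


def tile_memories_for_beam(base_sent_hiddens, base_length, beam_width):
    # Repeat every example beam_width times so the attention memories line up
    # with the beam-expanded decoder batch.
    return (seq2seq.tile_batch(base_sent_hiddens, multiplier=beam_width),
            seq2seq.tile_batch(base_length, multiplier=beam_width))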
Example #4
def editor_train(base_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 embed_matrix,
                 vocab_table,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = edit_encoder.random_noise_encoder(
                batch_size, edit_dim, norm_max)
        else:
            edit_vector = edit_encoder.accumulator_encoder(
                insert_word_embeds,
                delete_word_embeds,
                iw_len,
                dw_len,
                edit_dim,
                lamb_reg,
                norm_eps,
                norm_max,
                dropout_keep,
                enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    (train_dec_inp, train_dec_inp_len,
     train_dec_out, train_dec_out_len) = prepare_decoder_input_output(
         target_words, tgt_len, vocab_table)

    train_decoder = decoder.train_decoder(
        input_agenda,
        embeddings,
        train_dec_inp,
        base_sent_hidden_states,
        insert_word_embeds,
        delete_word_embeds,
        train_dec_inp_len,
        base_len,
        iw_len,
        dw_len,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states,
            insert_word_embeds,
            delete_word_embeds,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states,
            insert_word_embeds,
            delete_word_embeds,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
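Example #4 depends on `prepare_decoder_input_output`, which is not shown. A simplified sketch of the usual teacher-forcing shift, assuming START/STOP framing and ignoring padding (the real helper presumably places STOP at each sequence's true end):

import tensorflow as tf


def prepare_decoder_input_output_sketch(target_words, target_len, vocab_table):
    # Hypothetical stand-in: <s> w1 ... wn as decoder input, w1 ... wn </s> as target.
    batch_size = tf.shape(target_words)[0]
    start_id = tf.cast(vocab.get_token_id(vocab.START_TOKEN, vocab_table),
                       target_words.dtype)
    stop_id = tf.cast(vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
                      target_words.dtype)

    dec_inp = tf.concat([tf.fill([batch_size, 1], start_id), target_words], axis=1)
    dec_out = tf.concat([target_words, tf.fill([batch_size, 1], stop_id)], axis=1)
    return dec_inp, target_len + 1, dec_out, target_len + 1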
Example #5
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 recons_dense_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    output_word_embeds = vocab.embed_tokens(output_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    assert not kill_edit and not draw_edit

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector = accumulator_encoder(insert_word_embeds,
                                              delete_word_embeds,
                                              iw_len,
                                              dw_len,
                                              edit_dim,
                                              lamb_reg,
                                              norm_eps,
                                              norm_max,
                                              dropout_keep,
                                              enable_vae=enable_vae)

            # Dummy [1 x 1] constants standing in for the word-attention outputs.
            wa_inserted = (tf.constant([[0]]), tf.constant([[0]]))
            wa_deleted = (tf.constant([[0]]), tf.constant([[0]]))

    # [batch x agenda_dim]
    base_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    (train_dec_inp, train_dec_inp_len,
     train_dec_out, train_dec_out_len) = prepare_decoder_input_output(
         output_words, extended_output_words, output_len)

    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        wa_inserted,
        wa_deleted,
        train_dec_inp_len,
        base_len,
        src_len,
        tgt_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    edit_vector_recons = output_words_to_edit_vector(
        output_word_embeds, output_len, edit_dim, ctx_hidden_dim,
        ctx_hidden_layer, recons_dense_layers, swap_memory)
    optimizer.add_reconst_loss(edit_vector, edit_vector_recons)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
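`optimizer.add_reconst_loss` is not shown either; one plausible form, assuming a simple L2 reconstruction penalty registered as a regularization loss:

import tensorflow as tf


def add_reconst_loss_sketch(edit_vector, edit_vector_recons, weight=1.0):
    # Hypothetical stand-in for optimizer.add_reconst_loss.
    loss = tf.reduce_mean(tf.squared_difference(edit_vector, edit_vector_recons))
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, weight * loss)
    return loss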
Example #6
def attn_encoder(source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 source_lengths,
                 target_lengths,
                 iw_lengths,
                 dw_lengths,
                 transformer_params,
                 wa_hidden_dim,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 edit_dim,
                 micro_edit_ev_dim,
                 noise_scaler,
                 norm_eps,
                 norm_max,
                 dropout_keep=1.,
                 use_dropout=False,
                 swap_memory=False,
                 enable_vae=True):
    """
    Args:
        source_words:
        target_words:
        insert_words:
        delete_words:
        source_lengths:
        target_lengths:
        iw_lengths:
        dw_lengths:
        ctx_hidden_dim:
        ctx_hidden_layer:
        wa_hidden_dim:
        wa_hidden_layer:
        edit_dim:
        noise_scaler:
        norm_eps:
        norm_max:
        dropout_keep:

    Returns:

    """
    with tf.variable_scope(OPS_NAME):
        wa_inserted_last, wa_deleted_last = wa_accumulator(
            insert_words, delete_words, iw_lengths, dw_lengths, wa_hidden_dim)

        if use_dropout and dropout_keep < 1.:
            wa_inserted_last = tf.nn.dropout(wa_inserted_last, dropout_keep)
            wa_deleted_last = tf.nn.dropout(wa_deleted_last, dropout_keep)

        embedding_matrix = vocab.get_embeddings()
        embedding_layer = ConcatPosEmbedding(
            transformer_params.hidden_size, embedding_matrix,
            transformer_params.pos_encoding_dim)
        micro_ev_projection = tf.layers.Dense(micro_edit_ev_dim,
                                              activation=None,
                                              use_bias=True,
                                              name='micro_ev_proj')
        mev_extractor = TransformerMicroEditExtractor(embedding_layer,
                                                      micro_ev_projection,
                                                      transformer_params)

        cnx_tgt, micro_evs_st = mev_extractor(source_words, target_words,
                                              source_lengths, target_lengths)
        cnx_src, micro_evs_ts = mev_extractor(target_words, source_words,
                                              target_lengths, source_lengths)

        micro_ev_encoder = tf.make_template('micro_ev_encoder',
                                            context_encoder,
                                            hidden_dim=meve_hidden_dim,
                                            num_layers=meve_hidden_layers,
                                            swap_memory=swap_memory,
                                            use_dropout=use_dropout,
                                            dropout_keep=dropout_keep)

        aggreg_mev_st = micro_ev_encoder(micro_evs_st, source_lengths)
        aggreg_mev_ts = micro_ev_encoder(micro_evs_ts, target_lengths)

        aggreg_mev_st_last = sequence.last_relevant(aggreg_mev_st,
                                                    source_lengths)
        aggreg_mev_ts_last = sequence.last_relevant(aggreg_mev_ts,
                                                    target_lengths)

        if use_dropout and dropout_keep < 1.:
            aggreg_mev_st_last = tf.nn.dropout(aggreg_mev_st_last,
                                               dropout_keep)
            aggreg_mev_ts_last = tf.nn.dropout(aggreg_mev_ts_last,
                                               dropout_keep)

        features = tf.concat([
            aggreg_mev_st_last, aggreg_mev_ts_last, wa_inserted_last,
            wa_deleted_last
        ],
                             axis=1)

        edit_vector = tf.layers.dense(features,
                                      edit_dim,
                                      use_bias=False,
                                      name='encoder_ev')

        if enable_vae:
            edit_vector = sample_vMF(edit_vector, noise_scaler, norm_eps,
                                     norm_max)

        return edit_vector, (cnx_src, micro_evs_st), (cnx_tgt, micro_evs_ts)
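`tf.make_template` above is what lets the two `micro_ev_encoder` calls share one set of variables. A toy illustration of that sharing pattern (the function and names here are hypothetical):

import tensorflow as tf


def toy_context_encoder(inputs, hidden_dim):
    return tf.layers.dense(inputs, hidden_dim, name='proj')


shared_encoder = tf.make_template('shared_encoder', toy_context_encoder, hidden_dim=32)
a = shared_encoder(tf.zeros([4, 16]))
b = shared_encoder(tf.ones([4, 16]))  # reuses the same 'proj' kernel and bias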
Example #7
def init_from_embedding_matrix():
    embedding_matrix = vocab.get_embeddings()
    embed_layer = EmbeddingSharedWeights(embedding_matrix)
    tf.add_to_collection('embed_layer', embed_layer)
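The layer stored in the collection above is retrieved elsewhere in the graph with the matching lookup:

import tensorflow as tf

# Anywhere else in the same graph, after init_from_embedding_matrix() has run:
embed_layer = tf.get_collection('embed_layer')[0]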
Example #8
def editor_train(base_words,
                 output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    src_word_embeds = vocab.embed_tokens(source_words)
    tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    sent_encoder = tf.make_template('sent_encoder',
                                    encoder.source_sent_encoder,
                                    hidden_dim=hidden_dim,
                                    num_layer=num_encoder_layers,
                                    swap_memory=swap_memory,
                                    use_dropout=use_dropout,
                                    dropout_keep=dropout_keep)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = sent_encoder(
        base_word_embeds, base_len)

    assert not kill_edit and not draw_edit

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector, wa_inserted, wa_deleted = attn_encoder(
                src_word_embeds,
                tgt_word_embeds,
                insert_word_embeds,
                delete_word_embeds,
                src_len,
                tgt_len,
                iw_len,
                dw_len,
                ctx_hidden_dim,
                ctx_hidden_layer,
                wa_hidden_dim,
                wa_hidden_layer,
                meve_hidden_dim,
                meve_hidden_layers,
                edit_dim,
                micro_edit_ev_dim,
                num_heads,
                lamb_reg,
                norm_eps,
                norm_max,
                sent_encoder,
                use_dropout=use_dropout,
                dropout_keep=dropout_keep,
                swap_memory=swap_memory,
                enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    (train_dec_inp, train_dec_inp_len,
     train_dec_out, train_dec_out_len) = prepare_decoder_input_output(
         output_words, output_len, None)

    train_decoder = decoder.train_decoder(
        input_agenda,
        embeddings,
        train_dec_inp,
        base_sent_hidden_states,
        wa_inserted,
        wa_deleted,
        train_dec_inp_len,
        base_len,
        iw_len,
        dw_len,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
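`agn.linear`, which fuses the base-sentence embedding with the edit vector into the agenda, is not shown; a plausible sketch, assuming it is a single linear projection over the concatenation:

import tensorflow as tf


def agenda_linear_sketch(sent_embed, edit_vector, agenda_dim):
    # Hypothetical stand-in for agn.linear.
    features = tf.concat([sent_embed, edit_vector], axis=1)
    return tf.layers.dense(features, agenda_dim, use_bias=False, name='agenda_proj')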
Example #9
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    assert not kill_edit and not draw_edit

    # [batch x agenda_dim]
    base_agenda = linear(base_sent_embed, agenda_dim)

    (train_dec_inp, train_dec_inp_len,
     train_dec_out, train_dec_out_len) = prepare_decoder_input_output(
         output_words, extended_output_words, output_len)

    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        train_dec_inp_len,
        base_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

        add_decoder_attn_history_graph(infr_decoder)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len


def decoder_outputs_to_edit_vector(decoder_output, temperature_starter,
                                   decay_rate, decay_steps, edit_dim,
                                   enc_hidden_dim, enc_num_layers,
                                   dense_layers, swap_memory):
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        embeddings = vocab.get_embeddings()

        # Extend embedding matrix to support oov tokens
        unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
        unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
        unk_embeddings = tf.tile(unk_embed, [50, 1])

        # [VOCAB+50 x word_dim]
        embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

        global_step = tf.train.get_global_step()
        temperature = tf.train.exponential_decay(temperature_starter,
                                                 global_step,
                                                 decay_steps,
                                                 decay_rate,
                                                 name='temperature')
        tf.summary.scalar('temper', temperature, ['extra'])

        # [batch x max_len x VOCAB+50], softmax probabilities
        outputs = decoder.rnn_output(decoder_output)

        # replace non-positive probabilities with a tiny constant for numerical stability
        outputs = tf.where(tf.less_equal(outputs, 0),
                           tf.ones_like(outputs) * 1e-10, outputs)

        # sample (approximately) one-hot vectors via the Gumbel-Softmax relaxation
        dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

        # [batch x max_len x VOCAB+50], one_hot
        outputs_one_hot = dist.sample()

        # [batch x max_len x word_dim], one_hot^T * embedding_matrix
        outputs_embed = tf.einsum("btv,vd->btd", outputs_one_hot,
                                  embeddings_extended)

        # [batch]
        outputs_length = decoder.seq_length(decoder_output)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, sentence_embedding = encoder.source_sent_encoder(
            outputs_embed,
            outputs_length,
            enc_hidden_dim,
            enc_num_layers,
            use_dropout=False,
            dropout_keep=1.0,
            swap_memory=swap_memory)

        h = sentence_embedding
        for units in dense_layers:
            h = tf.layers.dense(h,
                                units,
                                activation=tf.nn.relu,
                                name='hidden_%s' % units)

        # [batch x edit_dim]
        edit_vector = tf.layers.dense(h,
                                      edit_dim,
                                      activation=None,
                                      name='linear')

    return edit_vector
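`decoder.rnn_output` and `decoder.seq_length` used above are assumed to be thin accessors over the `(outputs, state, length)` triple that `seq2seq.dynamic_decode` returns (see Example #1); a sketch of that assumption:


def rnn_output_sketch(decoder_output):
    # Hypothetical accessor: per-step output probabilities emitted by the decoder.
    outputs, _, _ = decoder_output
    return outputs.rnn_output


def seq_length_sketch(decoder_output):
    # Hypothetical accessor: decoded sequence lengths.
    _, _, length = decoder_output
    return length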