Example #1
def train_decoder(agenda,
                  embeddings,
                  extended_base_words,
                  oov,
                  dec_inputs,
                  dec_extended_inputs,
                  base_sent_hiddens,
                  insert_word_embeds,
                  delete_word_embeds,
                  dec_input_lengths,
                  base_length,
                  iw_length,
                  dw_length,
                  vocab_size,
                  attn_dim,
                  hidden_dim,
                  num_layer,
                  swap_memory,
                  enable_dropout=False,
                  dropout_keep=1.,
                  no_insert_delete_attn=False):
    with tf.variable_scope(OPS_NAME, 'decoder'):
        dec_input_embeds = vocab.embed_tokens(dec_inputs)
        last_ids = tf.cast(tf.expand_dims(dec_extended_inputs, 2), tf.float32)
        cell_input = tf.concat([dec_input_embeds, last_ids], axis=2)

        helper = seq2seq.TrainingHelper(cell_input,
                                        dec_input_lengths,
                                        name='train_helper')

        cell, zero_states = create_decoder_cell(
            agenda,
            extended_base_words,
            oov,
            base_sent_hiddens,
            insert_word_embeds,
            delete_word_embeds,
            base_length,
            iw_length,
            dw_length,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_layer,
            enable_dropout=enable_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

        decoder = seq2seq.BasicDecoder(cell, helper, zero_states)
        outputs, state, length = seq2seq.dynamic_decode(
            decoder, swap_memory=swap_memory)

        return outputs, state, length
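
Note: train_decoder follows the standard TF 1.x tf.contrib.seq2seq training pattern: wrap the teacher-forced inputs in a TrainingHelper, pair it with a cell in a BasicDecoder, and unroll with dynamic_decode. A minimal, self-contained sketch of just that pattern (assuming seq2seq is tf.contrib.seq2seq; the cell type and sizes below are illustrative, not taken from this repository):

import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, max_len, input_dim = 4, 10, 16      # illustrative sizes

# Teacher-forced decoder inputs and their true lengths.
inputs = tf.random_normal([batch_size, max_len, input_dim])
lengths = tf.fill([batch_size], max_len)

cell = tf.nn.rnn_cell.LSTMCell(32)
helper = seq2seq.TrainingHelper(inputs, lengths, name='train_helper')
decoder = seq2seq.BasicDecoder(cell, helper,
                               cell.zero_state(batch_size, tf.float32))

# outputs.rnn_output has shape [batch x max_len x 32]
outputs, final_state, final_lengths = seq2seq.dynamic_decode(
    decoder, swap_memory=True)
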
Example #2
    def fn(orig_ids):
        orig_ids = tf.cast(orig_ids, tf.int64)

        in_vocab_ids = tf.where(tf.less(orig_ids, vocab_size), orig_ids,
                                tf.ones_like(orig_ids) * vocab.OOV_TOKEN_ID)
        embeds = vocab.embed_tokens(in_vocab_ids)

        last_ids = tf.where(
            tf.equal(orig_ids, vocab.get_token_id(vocab.START_TOKEN)),
            tf.ones_like(orig_ids) * -1, orig_ids)
        last_ids = tf.cast(tf.expand_dims(last_ids, 2), tf.float32)

        cell_input = tf.concat([embeds, last_ids], axis=2)

        return cell_input
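
Note: the tf.where(tf.less(...)) idiom above clamps ids that fall outside the fixed vocabulary back to a single OOV id before the embedding lookup, while the raw extended id is kept separately as last_ids. A minimal sketch of the clamping step alone (the vocabulary size and OOV id are made-up placeholders for the values the repo's vocab module provides):

import tensorflow as tf

vocab_size = 5        # placeholder; the repo reads this from its vocab module
OOV_TOKEN_ID = 0      # placeholder for vocab.OOV_TOKEN_ID

orig_ids = tf.constant([[1, 3, 7, 9]], dtype=tf.int64)   # 7 and 9 are out of vocabulary

# Keep in-vocabulary ids, map everything else to the OOV id.
in_vocab_ids = tf.where(tf.less(orig_ids, vocab_size),
                        orig_ids,
                        tf.ones_like(orig_ids) * OOV_TOKEN_ID)
# evaluates to [[1, 3, 0, 0]]
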
Example #3
def prepare_output_embed(
    decoder_output,
    temperature_starter,
    decay_rate,
    decay_steps,
):
    # [VOCAB x word_dim]
    embeddings = vocab.get_embeddings()

    # Extend embedding matrix to support oov tokens
    unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
    unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
    unk_embeddings = tf.tile(unk_embed, [50, 1])

    # [VOCAB+50 x word_dim]
    embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

    global_step = tf.train.get_global_step()
    temperature = tf.train.exponential_decay(temperature_starter,
                                             global_step,
                                             decay_steps,
                                             decay_rate,
                                             name='temperature')
    tf.summary.scalar('temper', temperature, ['extra'])

    # [batch x max_len x VOCAB+50], softmax probabilities
    outputs = decoder.rnn_output(decoder_output)

    # replace non-positive probabilities with a small epsilon for numerical stability
    outputs = tf.where(tf.less_equal(outputs, 0),
                       tf.ones_like(outputs) * 1e-10, outputs)

    # build a relaxed one-hot categorical (Gumbel-Softmax) so sampling stays differentiable
    dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

    # [batch x max_len x VOCAB+50], approximately one-hot samples
    outputs_one_hot = dist.sample()

    # [batch x max_len x word_dim], one_hot^T * embedding_matrix
    outputs_embed = tf.einsum("btv,vd-> btd", outputs_one_hot,
                              embeddings_extended)

    return outputs_embed
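
Note: prepare_output_embed turns the decoder's softmax distribution into a differentiable "soft embedding": it samples from a RelaxedOneHotCategorical (Gumbel-Softmax) at a decaying temperature and multiplies the nearly one-hot sample with the embedding matrix via einsum. A self-contained sketch of that core step, assuming tfd refers to TensorFlow Probability's distributions (toy shapes, random inputs):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch, max_len, vocab_total, word_dim = 2, 3, 6, 4     # toy sizes
probs = tf.nn.softmax(tf.random_normal([batch, max_len, vocab_total]))
embedding_matrix = tf.random_normal([vocab_total, word_dim])

temperature = 0.5   # lower temperature -> samples closer to hard one-hot

# Differentiable, approximately one-hot samples over the vocabulary.
dist = tfd.RelaxedOneHotCategorical(temperature, probs=probs)
soft_one_hot = dist.sample()                           # [batch x max_len x vocab_total]

# Soft embedding lookup: a probability-weighted sum of embedding rows.
soft_embeds = tf.einsum('btv,vd->btd', soft_one_hot, embedding_matrix)
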
Example #4
    def fn(orig_ids):
        orig_ids = tf.cast(orig_ids, tf.int64)
        in_vocab_ids = tf.where(tf.less(orig_ids, vocab_size), orig_ids,
                                tf.ones_like(orig_ids) * vocab.OOV_TOKEN_ID)
        embeds = vocab.embed_tokens(in_vocab_ids)
        return embeds
Example #5
def editor_train(base_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 embed_matrix,
                 vocab_table,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = edit_encoder.random_noise_encoder(
                batch_size, edit_dim, norm_max)
        else:
            edit_vector = edit_encoder.accumulator_encoder(
                insert_word_embeds,
                delete_word_embeds,
                iw_len,
                dw_len,
                edit_dim,
                lamb_reg,
                norm_eps,
                norm_max,
                dropout_keep,
                enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(target_words, tgt_len, vocab_table)

    train_decoder = decoder.train_decoder(
        input_agenda,
        embeddings,
        train_dec_inp,
        base_sent_hidden_states,
        insert_word_embeds,
        delete_word_embeds,
        train_dec_inp_len,
        base_len,
        iw_len,
        dw_len,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states,
            insert_word_embeds,
            delete_word_embeds,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states,
            insert_word_embeds,
            delete_word_embeds,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
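
Note: seq.length_pre_embedding is not shown in these examples; the [batch] length tensors it returns are the usual "count the non-padding entries per row" computation. A hypothetical sketch of that idea, assuming padded id matrices and a padding id of 0 (the repo's helper may differ, e.g. it could operate on string tokens instead):

import tensorflow as tf

PAD_ID = 0   # assumed padding id; not taken from this repository

def length_from_padded_ids(ids):
    """Per-example length of a [batch x max_len] tensor of padded token ids."""
    mask = tf.not_equal(ids, PAD_ID)
    return tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)

# e.g. [[5, 2, 0, 0], [7, 0, 0, 0]] -> [2, 1]
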
Example #6
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 recons_dense_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    output_word_embeds = vocab.embed_tokens(output_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    assert kill_edit == False and draw_edit == False

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector = accumulator_encoder(insert_word_embeds,
                                              delete_word_embeds,
                                              iw_len,
                                              dw_len,
                                              edit_dim,
                                              lamb_reg,
                                              norm_eps,
                                              norm_max,
                                              dropout_keep,
                                              enable_vae=enable_vae)

            wa_inserted = (tf.constant([[0]]), tf.constant([[0]]))
            wa_deleted = (tf.constant([[0]]), tf.constant([[0]]))

    # [batch x agenda_dim]
    base_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        wa_inserted,
        wa_deleted,
        train_dec_inp_len,
        base_len,
        src_len,
        tgt_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    edit_vector_recons = output_words_to_edit_vector(
        output_word_embeds, output_len, edit_dim, ctx_hidden_dim,
        ctx_hidden_layer, recons_dense_layers, swap_memory)
    optimizer.add_reconst_loss(edit_vector, edit_vector_recons)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
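
Note: the last two calls re-encode the gold output sentence into a reconstructed edit vector and register a reconstruction penalty against the original one. optimizer.add_reconst_loss is not shown in these examples; a hypothetical minimal version of such a loss (the L2 form, the weight, and the use of tf.losses are assumptions, not the repository's implementation):

import tensorflow as tf

def add_reconst_loss(edit_vector, edit_vector_recons, weight=1.0):
    """Hypothetical stand-in: an L2 penalty between the edit vector and its reconstruction."""
    loss = weight * tf.reduce_mean(
        tf.reduce_sum(tf.square(edit_vector - edit_vector_recons), axis=-1))
    tf.losses.add_loss(loss)               # picked up later via tf.losses.get_total_loss()
    tf.summary.scalar('reconst_loss', loss)
    return loss
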
Example #7
    def _get_input_embeddings(self, ids):
        ids = tf.where(tf.less(ids, self.vocab_size), ids,
                       tf.ones_like(ids) * vocab.OOV_TOKEN_ID)

        return vocab.embed_tokens(ids)
Example #8
def editor_train(base_words,
                 output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    src_word_embeds = vocab.embed_tokens(source_words)
    tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    sent_encoder = tf.make_template('sent_encoder',
                                    encoder.source_sent_encoder,
                                    hidden_dim=hidden_dim,
                                    num_layer=num_encoder_layers,
                                    swap_memory=swap_memory,
                                    use_dropout=use_dropout,
                                    dropout_keep=dropout_keep)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = sent_encoder(
        base_word_embeds, base_len)

    assert kill_edit == False and draw_edit == False

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector, wa_inserted, wa_deleted = attn_encoder(
                src_word_embeds,
                tgt_word_embeds,
                insert_word_embeds,
                delete_word_embeds,
                src_len,
                tgt_len,
                iw_len,
                dw_len,
                ctx_hidden_dim,
                ctx_hidden_layer,
                wa_hidden_dim,
                wa_hidden_layer,
                meve_hidden_dim,
                meve_hidden_layers,
                edit_dim,
                micro_edit_ev_dim,
                num_heads,
                lamb_reg,
                norm_eps,
                norm_max,
                sent_encoder,
                use_dropout=use_dropout,
                dropout_keep=dropout_keep,
                swap_memory=swap_memory,
                enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, output_len, None)

    train_decoder = decoder.train_decoder(
        input_agenda,
        embeddings,
        train_dec_inp,
        base_sent_hidden_states,
        wa_inserted,
        wa_deleted,
        train_dec_inp_len,
        base_len,
        iw_len,
        dw_len,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
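
Note: unlike the earlier variants, this version wraps encoder.source_sent_encoder in tf.make_template, so every call to sent_encoder (including the one handed to attn_encoder) reuses the same encoder variables. A minimal, self-contained sketch of the tf.make_template sharing pattern, with a toy projection standing in for the encoder:

import tensorflow as tf

def toy_encoder(x, hidden_dim):
    # Stand-in for encoder.source_sent_encoder: any function that creates
    # its variables with tf.get_variable can be shared through make_template.
    w = tf.get_variable('proj_w', [4, hidden_dim])
    return tf.matmul(x, w)

# Every call to shared_encoder reuses the single 'sent_encoder/proj_w' variable.
shared_encoder = tf.make_template('sent_encoder', toy_encoder, hidden_dim=8)

a = shared_encoder(tf.random_normal([2, 4]))
b = shared_encoder(tf.random_normal([3, 4]))   # second call: variables are reused, not recreated
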
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    assert kill_edit == False and draw_edit == False

    # [batch x agenda_dim]
    base_agenda = linear(base_sent_embed, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        train_dec_inp_len,
        base_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

        add_decoder_attn_history_graph(infr_decoder)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def decoder_outputs_to_edit_vector(decoder_output, temperature_starter,
                                   decay_rate, decay_steps, edit_dim,
                                   enc_hidden_dim, enc_num_layers,
                                   dense_layers, swap_memory):
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        embeddings = vocab.get_embeddings()

        # Extend embedding matrix to support oov tokens
        unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
        unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
        unk_embeddings = tf.tile(unk_embed, [50, 1])

        # [VOCAB+50 x word_dim]
        embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

        global_step = tf.train.get_global_step()
        temperature = tf.train.exponential_decay(temperature_starter,
                                                 global_step,
                                                 decay_steps,
                                                 decay_rate,
                                                 name='temperature')
        tf.summary.scalar('temper', temperature, ['extra'])

        # [batch x max_len x VOCAB+50], softmax probabilities
        outputs = decoder.rnn_output(decoder_output)

        # replace non-positive probabilities with a small epsilon for numerical stability
        outputs = tf.where(tf.less_equal(outputs, 0),
                           tf.ones_like(outputs) * 1e-10, outputs)

        # build a relaxed one-hot categorical (Gumbel-Softmax) so sampling stays differentiable
        dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

        # [batch x max_len x VOCAB+50], approximately one-hot samples
        outputs_one_hot = dist.sample()

        # [batch x max_len x word_dim], one_hot^T * embedding_matrix
        outputs_embed = tf.einsum("btv,vd-> btd", outputs_one_hot,
                                  embeddings_extended)

        # [batch]
        outputs_length = decoder.seq_length(decoder_output)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, sentence_embedding = encoder.source_sent_encoder(
            outputs_embed,
            outputs_length,
            enc_hidden_dim,
            enc_num_layers,
            use_dropout=False,
            dropout_keep=1.0,
            swap_memory=swap_memory)

        h = sentence_embedding
        for l in dense_layers:
            h = tf.layers.dense(h,
                                l,
                                activation='relu',
                                name='hidden_%s' % (l))

        # [batch x edit_dim]
        edit_vector = tf.layers.dense(h,
                                      edit_dim,
                                      activation=None,
                                      name='linear')

    return edit_vector
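
Note: both this function and prepare_output_embed above anneal the Gumbel-Softmax temperature with tf.train.exponential_decay, typically with decay_rate < 1 so samples become closer to hard one-hot vectors as training progresses. A small sketch of the schedule on its own (the starting value and decay settings are illustrative):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()

# temperature = 5.0 * 0.9 ** (global_step / 1000), evaluated continuously
temperature = tf.train.exponential_decay(5.0,            # temperature_starter (illustrative)
                                         global_step,
                                         decay_steps=1000,
                                         decay_rate=0.9,
                                         name='temperature')
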