Example #1
def prepare_decoder_input_output(tgt_words, tgt_len, vocab_table):
    """
    Args:
        tgt_words: tensor of word ids, [batch x max_len]
        tgt_len: vector of sentence lengths, [batch]
        vocab_table: string-to-id vocabulary lookup table

    Returns:
        dec_input: tensor of word ids, [batch x max_len+1]
        dec_input_len: vector of sentence lengths, [batch]
        dec_output: tensor of word ids, [batch x max_len+1]
        dec_output_len: vector of sentence lengths, [batch]

    """
    start_token_id = vocab.get_token_id(vocab.START_TOKEN, vocab_table)
    stop_token_id = vocab.get_token_id(vocab.STOP_TOKEN, vocab_table)
    pad_token_id = vocab.get_token_id(vocab.PAD_TOKEN, vocab_table)

    dec_input = decoder.prepare_decoder_inputs(tgt_words, start_token_id)
    dec_input_len = seq.length_pre_embedding(dec_input)

    dec_output = decoder.prepare_decoder_output(tgt_words, tgt_len,
                                                stop_token_id, pad_token_id)
    dec_output_len = seq.length_pre_embedding(dec_output)

    return dec_input, dec_input_len, dec_output, dec_output_len
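
# The decoder.prepare_decoder_* helpers called above are project-specific and
# not shown here. As a hedged illustration of what the docstring shapes imply
# (START prepended to the decoder input, STOP placed right after the last real
# target token), here is a minimal standalone sketch using plain TF 1.x ops;
# the sketch_* names and the toy ids are assumptions, not repo code.
import tensorflow as tf


def sketch_prepare_decoder_inputs(tgt_words, start_token_id):
    # [batch x max_len] -> [batch x max_len+1] with START prepended
    batch_size = tf.shape(tgt_words)[0]
    start_col = tf.fill([batch_size, 1], tf.cast(start_token_id, tgt_words.dtype))
    return tf.concat([start_col, tgt_words], axis=1)


def sketch_prepare_decoder_output(tgt_words, tgt_len, stop_token_id, pad_token_id):
    # [batch x max_len] -> [batch x max_len+1] with STOP after the last real token
    batch_size = tf.shape(tgt_words)[0]
    pad_col = tf.fill([batch_size, 1], tf.cast(pad_token_id, tgt_words.dtype))
    padded = tf.concat([tgt_words, pad_col], axis=1)
    stop_mask = tf.cast(tf.one_hot(tgt_len, tf.shape(padded)[1]), tf.bool)
    stop_vals = tf.fill(tf.shape(padded), tf.cast(stop_token_id, padded.dtype))
    return tf.where(stop_mask, stop_vals, padded)


def sketch_prepare_demo():
    with tf.Graph().as_default():
        tgt = tf.constant([[5, 6, 0], [7, 8, 9]], dtype=tf.int64)  # 0 = PAD
        tgt_len = tf.constant([2, 3])
        dec_in = sketch_prepare_decoder_inputs(tgt, start_token_id=1)
        dec_out = sketch_prepare_decoder_output(tgt, tgt_len,
                                                stop_token_id=2, pad_token_id=0)
        with tf.Session() as sess:
            print(sess.run(dec_in))   # [[1 5 6 0] [1 7 8 9]]
            print(sess.run(dec_out))  # [[5 6 2 0] [7 8 9 2]]

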
def test_wtf():
    with tf.Graph().as_default():
        V, embed_matrix = vocab.read_word_embeddings(
            Path('../data') / 'word_vectors' / 'glove.6B.300d_yelp.txt',
            300,
            10000
        )

        table = vocab.create_vocab_lookup_tables(V)
        vocab_s2i = table[vocab.STR_TO_INT]
        vocab_i2s = table[vocab.INT_TO_STR]

        dataset = input_fn('../data/yelp_dataset_large_split/train.tsv', table, 64, 1)
        iterator = dataset.make_initializable_iterator()

        (src, tgt, iw, dw), _ = iterator.get_next()
        src_len = length_pre_embedding(src)
        tgt_len = length_pre_embedding(tgt)
        iw_len = length_pre_embedding(iw)
        dw_len = length_pre_embedding(dw)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, vocab.get_token_id(vocab.START_TOKEN, vocab_s2i))

        dec_output = decoder.prepare_decoder_output(tgt, tgt_len, vocab.get_token_id(vocab.STOP_TOKEN, vocab_s2i),
                                                    vocab.get_token_id(vocab.PAD_TOKEN, vocab_s2i))

        t_src = vocab_i2s.lookup(src)
        t_tgt = vocab_i2s.lookup(tgt)
        t_iw = vocab_i2s.lookup(iw)
        t_dw = vocab_i2s.lookup(dw)

        t_do = vocab_i2s.lookup(dec_output)
        t_di = vocab_i2s.lookup(dec_inputs)

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer()])
            sess.run(iterator.initializer)

            while True:
                try:
                    # src, tgt, iw, dw = sess.run([src, tgt, iw, dw])
                    ts, tt, tiw, tdw, tdo, tdi = sess.run([t_src, t_tgt, t_iw, t_dw, t_do, t_di])
                except tf.errors.OutOfRangeError:
                    break
Example #3
    def fn(orig_ids):
        orig_ids = tf.cast(orig_ids, tf.int64)

        # map ids outside the vocabulary to the OOV token before embedding
        in_vocab_ids = tf.where(tf.less(orig_ids, vocab_size), orig_ids,
                                tf.ones_like(orig_ids) * vocab.OOV_TOKEN_ID)
        embeds = vocab.embed_tokens(in_vocab_ids)

        # keep the raw (possibly extended-vocabulary) ids as an extra scalar
        # feature, with the START token mapped to -1
        last_ids = tf.where(
            tf.equal(orig_ids, vocab.get_token_id(vocab.START_TOKEN)),
            tf.ones_like(orig_ids) * -1, orig_ids)
        last_ids = tf.cast(tf.expand_dims(last_ids, 2), tf.float32)

        # [batch x max_len x embed_dim+1]
        cell_input = tf.concat([embeds, last_ids], axis=2)

        return cell_input
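
# A tiny self-contained check of the id handling above, with assumed toy values
# (vocab_size=5, OOV id=1, START id=2); purely illustrative, not project code.
import tensorflow as tf


def sketch_oov_clamp_demo():
    with tf.Graph().as_default():
        vocab_size, oov_id, start_id = 5, 1, 2
        orig_ids = tf.constant([[2, 4, 7, 9]], dtype=tf.int64)  # 7, 9 out of vocab
        in_vocab_ids = tf.where(tf.less(orig_ids, vocab_size), orig_ids,
                                tf.ones_like(orig_ids) * oov_id)
        last_ids = tf.where(tf.equal(orig_ids, start_id),
                            tf.ones_like(orig_ids) * -1, orig_ids)
        with tf.Session() as sess:
            print(sess.run(in_vocab_ids))  # [[2 4 1 1]]
            print(sess.run(last_ids))      # [[-1  4  7  9]]
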
Example #4
def prepare_output_embed(
    decoder_output,
    temperature_starter,
    decay_rate,
    decay_steps,
):
    # [VOCAB x word_dim]
    embeddings = vocab.get_embeddings()

    # Extend embedding matrix to support oov tokens
    unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
    unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
    unk_embeddings = tf.tile(unk_embed, [50, 1])

    # [VOCAB+50 x word_dim]
    embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

    global_step = tf.train.get_global_step()
    temperature = tf.train.exponential_decay(temperature_starter,
                                             global_step,
                                             decay_steps,
                                             decay_rate,
                                             name='temperature')
    tf.summary.scalar('temper', temperature, ['extra'])

    # [batch x max_len x VOCAB+50], softmax probabilities
    outputs = decoder.rnn_output(decoder_output)

    # clamp non-positive probabilities to a small epsilon for numerical stability
    outputs = tf.where(tf.less_equal(outputs, 0),
                       tf.ones_like(outputs) * 1e-10, outputs)

    # relax the categorical distribution with the annealed temperature (Gumbel-Softmax)
    dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

    # [batch x max_len x VOCAB+50], approximately one-hot samples
    outputs_one_hot = dist.sample()

    # [batch x max_len x word_dim], one_hot^T * embedding_matrix
    outputs_embed = tf.einsum("btv,vd->btd", outputs_one_hot,
                              embeddings_extended)

    return outputs_embed
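
# A minimal standalone illustration of the relaxation step above: sample an
# (approximately one-hot) Gumbel-Softmax vector over the vocabulary and project
# it back into embedding space with einsum. Shapes and values are made up, and
# tfd is assumed to come from tensorflow_probability (or the equivalent
# tf.contrib.distributions in older TF 1.x builds); illustration only.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def sketch_relaxed_embed_demo():
    with tf.Graph().as_default():
        batch, max_len, vocab_plus_oov, word_dim = 2, 3, 7, 4
        probs = tf.nn.softmax(tf.random_normal([batch, max_len, vocab_plus_oov]))
        embeddings_extended = tf.random_normal([vocab_plus_oov, word_dim])

        dist = tfd.RelaxedOneHotCategorical(temperature=0.5, probs=probs)
        outputs_one_hot = dist.sample()  # [batch x max_len x vocab_plus_oov]
        outputs_embed = tf.einsum("btv,vd->btd", outputs_one_hot,
                                  embeddings_extended)

        with tf.Session() as sess:
            print(sess.run(tf.shape(outputs_embed)))  # [2 3 4]
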
Example #5
def editor_train(base_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 embed_matrix,
                 vocab_table,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = edit_encoder.random_noise_encoder(
                batch_size, edit_dim, norm_max)
        else:
            edit_vector = edit_encoder.accumulator_encoder(
                insert_word_embeds,
                delete_word_embeds,
                iw_len,
                dw_len,
                edit_dim,
                lamb_reg,
                norm_eps,
                norm_max,
                dropout_keep,
                enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(target_words, tgt_len, vocab_table)

    train_decoder = decoder.train_decoder(
        input_agenda,
        embeddings,
        train_dec_inp,
        base_sent_hidden_states,
        insert_word_embeds,
        delete_word_embeds,
        train_dec_inp_len,
        base_len,
        iw_len,
        dw_len,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states,
            insert_word_embeds,
            delete_word_embeds,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states,
            insert_word_embeds,
            delete_word_embeds,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
Example #6
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 recons_dense_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    output_word_embeds = vocab.embed_tokens(output_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    assert not kill_edit and not draw_edit

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector = accumulator_encoder(insert_word_embeds,
                                              delete_word_embeds,
                                              iw_len,
                                              dw_len,
                                              edit_dim,
                                              lamb_reg,
                                              norm_eps,
                                              norm_max,
                                              dropout_keep,
                                              enable_vae=enable_vae)

            # dummy word-alignment outputs for the accumulator-encoder path
            wa_inserted = (tf.constant([[0]]), tf.constant([[0]]))
            wa_deleted = (tf.constant([[0]]), tf.constant([[0]]))

    # [batch x agenda_dim]
    base_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        wa_inserted,
        wa_deleted,
        train_dec_inp_len,
        base_len,
        src_len,
        tgt_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            src_len,
            tgt_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    edit_vector_recons = output_words_to_edit_vector(
        output_word_embeds, output_len, edit_dim, ctx_hidden_dim,
        ctx_hidden_layer, recons_dense_layers, swap_memory)
    optimizer.add_reconst_loss(edit_vector, edit_vector_recons)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
Example #7
def editor_train(base_words,
                 output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    src_word_embeds = vocab.embed_tokens(source_words)
    tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    sent_encoder = tf.make_template('sent_encoder',
                                    encoder.source_sent_encoder,
                                    hidden_dim=hidden_dim,
                                    num_layer=num_encoder_layers,
                                    swap_memory=swap_memory,
                                    use_dropout=use_dropout,
                                    dropout_keep=dropout_keep)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = sent_encoder(
        base_word_embeds, base_len)

    assert not kill_edit and not draw_edit

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector, wa_inserted, wa_deleted = attn_encoder(
                src_word_embeds,
                tgt_word_embeds,
                insert_word_embeds,
                delete_word_embeds,
                src_len,
                tgt_len,
                iw_len,
                dw_len,
                ctx_hidden_dim,
                ctx_hidden_layer,
                wa_hidden_dim,
                wa_hidden_layer,
                meve_hidden_dim,
                meve_hidden_layers,
                edit_dim,
                micro_edit_ev_dim,
                num_heads,
                lamb_reg,
                norm_eps,
                norm_max,
                sent_encoder,
                use_dropout=use_dropout,
                dropout_keep=dropout_keep,
                swap_memory=swap_memory,
                enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, output_len, None)

    train_decoder = decoder.train_decoder(
        input_agenda,
        embeddings,
        train_dec_inp,
        base_sent_hidden_states,
        wa_inserted,
        wa_deleted,
        train_dec_inp_len,
        base_len,
        iw_len,
        dw_len,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda,
            embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted,
            wa_deleted,
            base_len,
            iw_len,
            dw_len,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
Example #8
def editor_train(base_words,
                 extended_base_words,
                 output_words,
                 extended_output_words,
                 source_words,
                 target_words,
                 insert_words,
                 delete_words,
                 oov,
                 vocab_size,
                 hidden_dim,
                 agenda_dim,
                 edit_dim,
                 micro_edit_ev_dim,
                 num_heads,
                 num_encoder_layers,
                 num_decoder_layers,
                 attn_dim,
                 beam_width,
                 ctx_hidden_dim,
                 ctx_hidden_layer,
                 wa_hidden_dim,
                 wa_hidden_layer,
                 meve_hidden_dim,
                 meve_hidden_layers,
                 max_sent_length,
                 dropout_keep,
                 lamb_reg,
                 norm_eps,
                 norm_max,
                 kill_edit,
                 draw_edit,
                 swap_memory,
                 use_beam_decoder=False,
                 use_dropout=False,
                 no_insert_delete_attn=False,
                 enable_vae=True):
    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds,
        base_len,
        hidden_dim,
        num_encoder_layers,
        use_dropout=use_dropout,
        dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    assert not kill_edit and not draw_edit

    # [batch x agenda_dim]
    base_agenda = linear(base_sent_embed, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda,
        embeddings,
        extended_base_words,
        oov,
        train_dec_inp,
        train_dec_inp_extended,
        base_sent_hidden_states,
        train_dec_inp_len,
        base_len,
        vocab_size,
        attn_dim,
        hidden_dim,
        num_decoder_layers,
        swap_memory,
        enable_dropout=use_dropout,
        dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            beam_width,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda,
            embeddings,
            extended_base_words,
            oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len,
            vocab_size,
            attn_dim,
            hidden_dim,
            num_decoder_layers,
            max_sent_length,
            swap_memory,
            enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

        add_decoder_attn_history_graph(infr_decoder)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len


def decoder_outputs_to_edit_vector(decoder_output, temperature_starter,
                                   decay_rate, decay_steps, edit_dim,
                                   enc_hidden_dim, enc_num_layers,
                                   dense_layers, swap_memory):
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        embeddings = vocab.get_embeddings()

        # Extend embedding matrix to support oov tokens
        unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
        unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
        unk_embeddings = tf.tile(unk_embed, [50, 1])

        # [VOCAB+50 x word_dim]
        embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

        global_step = tf.train.get_global_step()
        temperature = tf.train.exponential_decay(temperature_starter,
                                                 global_step,
                                                 decay_steps,
                                                 decay_rate,
                                                 name='temperature')
        tf.summary.scalar('temper', temperature, ['extra'])

        # [batch x max_len x VOCAB+50], softmax probabilities
        outputs = decoder.rnn_output(decoder_output)

        # clamp non-positive probabilities to a small epsilon for numerical stability
        outputs = tf.where(tf.less_equal(outputs, 0),
                           tf.ones_like(outputs) * 1e-10, outputs)

        # relax the categorical distribution with the annealed temperature (Gumbel-Softmax)
        dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

        # [batch x max_len x VOCAB+50], approximately one-hot samples
        outputs_one_hot = dist.sample()

        # [batch x max_len x word_dim], one_hot^T * embedding_matrix
        outputs_embed = tf.einsum("btv,vd->btd", outputs_one_hot,
                                  embeddings_extended)

        # [batch]
        outputs_length = decoder.seq_length(decoder_output)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, sentence_embedding = encoder.source_sent_encoder(
            outputs_embed,
            outputs_length,
            enc_hidden_dim,
            enc_num_layers,
            use_dropout=False,
            dropout_keep=1.0,
            swap_memory=swap_memory)

        h = sentence_embedding
        for layer_size in dense_layers:
            h = tf.layers.dense(h,
                                layer_size,
                                activation=tf.nn.relu,
                                name='hidden_%s' % layer_size)

        # [batch x edit_dim]
        edit_vector = tf.layers.dense(h,
                                      edit_dim,
                                      activation=None,
                                      name='linear')

    return edit_vector