def residual_decoder(agenda, dec_inputs, dec_input_lengths, hidden_dim, num_layer, swap_memory, enable_dropout=False, dropout_keep=1., name=None):
    """Build a train-time decoder whose inputs carry the agenda at every step.

    Each decoder input embedding is concatenated with the (time-tiled) agenda
    vector before being fed to a multi-layer RNN decoded with a TrainingHelper.

    Args:
        agenda: [batch x agenda_dim] agenda vector.
        dec_inputs: [batch x max_len] decoder input token ids.
        dec_input_lengths: [batch] true decoder input lengths.
        hidden_dim: total hidden size; each RNN layer gets hidden_dim // 2.
        num_layer: number of stacked RNN layers.
        swap_memory: passed through to dynamic_decode.
        enable_dropout: whether the RNN layers apply dropout.
        dropout_keep: dropout keep probability.
        name: optional variable-scope name.

    Returns:
        (outputs, state, length) as produced by seq2seq.dynamic_decode.
    """
    with tf.variable_scope(name, 'residual_decoder', []):
        batch_size = tf.shape(dec_inputs)[0]
        embed_matrix = vocab.get_embeddings()

        # Token ids -> [batch x max_len x word_dim] embeddings.
        embedded_inputs = tf.nn.embedding_lookup(embed_matrix, dec_inputs)
        max_len = tf.shape(embedded_inputs)[1]

        # Broadcast the agenda across time:
        # [batch x agenda_dim] -> [batch x max_len x agenda_dim]
        tiled_agenda = tf.tile(tf.expand_dims(agenda, axis=1), [1, max_len, 1])

        # [batch x max_len x word_dim+agenda_dim]
        augmented_inputs = tf.concat([embedded_inputs, tiled_agenda], axis=2)

        helper = seq2seq.TrainingHelper(augmented_inputs, dec_input_lengths, name='train_helper')

        rnn_layers = [
            create_rnn_layer(layer_idx, hidden_dim // 2, enable_dropout, dropout_keep)
            for layer_idx in range(num_layer)
        ]
        cell = tf_rnn.MultiRNNCell(rnn_layers)

        initial_states = create_trainable_initial_states(batch_size, cell)
        projection = DecoderOutputLayer(embed_matrix)

        basic_decoder = seq2seq.BasicDecoder(cell, helper, initial_states, projection)
        return seq2seq.dynamic_decode(basic_decoder, swap_memory=swap_memory)
def prepare_output_embed(
        decoder_output,
        temperature_starter,
        decay_rate,
        decay_steps,
        num_oov=50,
):
    """Turn decoder softmax outputs into differentiable word embeddings.

    Samples a Gumbel-softmax (RelaxedOneHotCategorical) relaxation of the
    decoder's output distribution and multiplies it with the OOV-extended
    embedding matrix, so gradients can flow through the word choice.

    Args:
        decoder_output: decoder output structure accepted by decoder.rnn_output().
        temperature_starter: initial relaxation temperature.
        decay_rate: exponential decay rate applied to the temperature.
        decay_steps: number of steps per decay period.
        num_oov: number of OOV slots appended to the vocabulary; all map onto
            the UNK embedding. Defaults to 50 (the previously hard-coded value).

    Returns:
        [batch x max_len x word_dim] soft embeddings of the decoded words.
    """
    # [VOCAB x word_dim]
    embeddings = vocab.get_embeddings()

    # Extend the embedding matrix so extended (OOV) ids resolve to UNK.
    unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
    unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
    unk_embeddings = tf.tile(unk_embed, [num_oov, 1])

    # [VOCAB+num_oov x word_dim]
    embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

    # Anneal the sampling temperature as training progresses.
    # NOTE(review): tf.train.get_global_step() returns None when no global
    # step exists — assumed to have been created by the training loop.
    global_step = tf.train.get_global_step()
    temperature = tf.train.exponential_decay(temperature_starter,
                                             global_step,
                                             decay_steps, decay_rate,
                                             name='temperature')
    tf.summary.scalar('temper', temperature, ['extra'])

    # [batch x max_len x VOCAB+num_oov], softmax probabilities
    outputs = decoder.rnn_output(decoder_output)

    # Replace non-positive probabilities for numerical stability (the relaxed
    # categorical takes log(probs) internally).
    outputs = tf.where(tf.less_equal(outputs, 0), tf.ones_like(outputs) * 1e-10, outputs)

    # Draw a soft one-hot sample from the word distribution.
    dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

    # [batch x max_len x VOCAB+num_oov], approximately one-hot
    outputs_one_hot = dist.sample()

    # [batch x max_len x word_dim] = one_hot^T * embedding_matrix
    outputs_embed = tf.einsum("btv,vd-> btd", outputs_one_hot, embeddings_extended)

    return outputs_embed
def create_decoder_cell(agenda, extended_base_words, oov, base_sent_hiddens, mev_st, mev_ts,
                        base_length, iw_length, dw_length,
                        vocab_size, attn_dim, hidden_dim, num_layer,
                        enable_alignment_history=False, enable_dropout=False, dropout_keep=1.,
                        no_insert_delete_attn=False, beam_width=None):
    """Assemble the pointer-generator decoder cell and its trainable zero state.

    Builds Bahdanau attention over the encoded base sentence plus one (or two)
    micro-edit-vector attentions, stacks an attention-wrapped bottom LSTM with
    plain LSTM layers, augments the stack with the agenda, and wraps everything
    in a pointer-generator copy mechanism.

    Returns:
        (pg_cell, zero_state): the wrapped decoder cell and its initial state.
    """
    # Attention over the encoded base sentence.
    base_attn = seq2seq.BahdanauAttention(attn_dim, base_sent_hiddens, base_length, name='base_attn')

    cnx_src, micro_evs_st = mev_st
    # Attention whose memory (scoring keys) is the source context...
    mev_st_attn = seq2seq.BahdanauAttention(attn_dim, cnx_src, iw_length, name='mev_st_attn')
    # HACK: ...but whose attended values are the source->target micro edit
    # vectors — overwrites a private attribute of the TF attention class.
    mev_st_attn._values = micro_evs_st

    attns = [base_attn, mev_st_attn]

    if not no_insert_delete_attn:
        # Symmetric target->source micro-edit attention, same private-attr hack.
        cnx_tgt, micro_evs_ts = mev_ts
        mev_ts_attn = seq2seq.BahdanauAttention(attn_dim, cnx_tgt, dw_length, name='mev_ts_attn')
        mev_ts_attn._values = micro_evs_ts
        attns += [mev_ts_attn]

    # NOTE(review): the `enable_alignment_history` argument is dead — it is
    # unconditionally overwritten below so alignment history is kept only
    # when not training.
    is_training = tf.get_collection('is_training')[0]
    enable_alignment_history = not is_training

    # Bottom layer carries the attention wrapper; upper layers are plain LSTMs.
    bottom_cell = tf_rnn.LSTMCell(hidden_dim, name='bottom_cell')
    bottom_attn_cell = seq2seq.AttentionWrapper(
        bottom_cell,
        tuple(attns),
        alignment_history=enable_alignment_history,
        output_attention=False,
        name='att_bottom_cell')

    all_cells = [bottom_attn_cell]

    # Remaining (num_layer - 1) layers, with optional output dropout.
    num_layer -= 1
    for i in range(num_layer):
        cell = tf_rnn.LSTMCell(hidden_dim, name='layer_%s' % (i + 1))
        if enable_dropout and dropout_keep < 1.:
            cell = tf_rnn.DropoutWrapper(cell, output_keep_prob=dropout_keep)
        all_cells.append(cell)

    decoder_cell = AttentionAugmentRNNCell(all_cells)
    decoder_cell.set_agenda(agenda)
    # Index 0 == base_attn: the attention used as the copy/pointer source.
    decoder_cell.set_source_attn_index(0)

    output_layer = DecoderOutputLayer(vocab.get_embeddings())
    # 50 == OOV slot budget appended to the vocabulary — TODO confirm it
    # matches the OOV budget used by the rest of the pipeline.
    pg_cell = PointerGeneratorWrapper(
        decoder_cell,
        extended_base_words,
        50,
        output_layer,
        vocab_size,
        decoder_cell.get_source_attention,
        name='PointerGeneratorWrapper')

    # Under beam search the batch is tiled beam_width times; recover the true
    # batch size for the trainable zero state.
    if beam_width:
        true_batch_size = tf.cast(
            tf.shape(base_sent_hiddens)[0] / beam_width, tf.int32)
    else:
        true_batch_size = tf.shape(base_sent_hiddens)[0]

    zero_state = create_trainable_zero_state(decoder_cell, true_batch_size, beam_width=beam_width)

    return pg_cell, zero_state
def editor_train(base_words, source_words, target_words, insert_words, delete_words,
                 embed_matrix, vocab_table, hidden_dim, agenda_dim, edit_dim,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 max_sent_length, dropout_keep, lamb_reg, norm_eps, norm_max,
                 kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    """Build the neural-editor training and inference graphs.

    Encodes the base sentence, derives an edit vector from the insert/delete
    word sets (or a zero / random vector under the kill_edit / draw_edit
    ablations), combines both into an agenda, and constructs the training
    decoder plus a beam-search or greedy inference decoder.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len)
    """
    batch_size = tf.shape(source_words)[0]

    # [batch] true (pre-embedding) sequence lengths.
    base_len = seq.length_pre_embedding(base_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len, hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep, swap_memory=swap_memory)

    # [batch x edit_dim]
    if kill_edit:
        # Ablation: suppress the edit signal entirely.
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    elif draw_edit:
        # Ablation: sample the edit vector from noise instead of encoding it.
        edit_vector = edit_encoder.random_noise_encoder(batch_size, edit_dim, norm_max)
    else:
        edit_vector = edit_encoder.accumulator_encoder(
            insert_word_embeds, delete_word_embeds,
            iw_len, dw_len, edit_dim, lamb_reg, norm_eps, norm_max,
            dropout_keep, enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
        train_dec_out, train_dec_out_len = prepare_decoder_input_output(
            target_words, tgt_len, vocab_table)

    train_decoder = decoder.train_decoder(
        input_agenda, embeddings, train_dec_inp, base_sent_hidden_states,
        insert_word_embeds, delete_word_embeds,
        train_dec_inp_len, base_len, iw_len, dw_len,
        attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states, insert_word_embeds, delete_word_embeds,
            base_len, iw_len, dw_len,
            attn_dim, hidden_dim, num_decoder_layers, max_sent_length,
            beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states, insert_word_embeds, delete_word_embeds,
            base_len, iw_len, dw_len,
            attn_dim, hidden_dim, num_decoder_layers, max_sent_length,
            swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def editor_train(base_words, extended_base_words, output_words, extended_output_words,
                 source_words, target_words, insert_words, delete_words, oov, vocab_size,
                 hidden_dim, agenda_dim, edit_dim, micro_edit_ev_dim, num_heads,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim, wa_hidden_layer,
                 meve_hidden_dim, meve_hidden_layers, recons_dense_layers, max_sent_length,
                 dropout_keep, lamb_reg, norm_eps, norm_max, kill_edit, draw_edit,
                 swap_memory, use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    """Build train/inference graphs for the pointer-generator editor variant
    that adds an edit-vector reconstruction loss.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len)
    """
    batch_size = tf.shape(source_words)[0]

    # [batch] true (pre-embedding) sequence lengths.
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    output_word_embeds = vocab.embed_tokens(output_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len, hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep, swap_memory=swap_memory)

    # NOTE(review): assert-based config validation — this variant supports only
    # the encoded-edit path; `assert` is stripped under `python -O`.
    assert kill_edit == False and draw_edit == False

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector = accumulator_encoder(insert_word_embeds,
                                              delete_word_embeds,
                                              iw_len, dw_len, edit_dim,
                                              lamb_reg, norm_eps,
                                              norm_max, dropout_keep,
                                              enable_vae=enable_vae)

    # NOTE(review): dummy placeholders — each is a pair of constant tensors,
    # presumably standing in for unused writer-attention outputs; verify
    # against decoder.train_decoder's expectations.
    wa_inserted, wa_deleted = (tf.constant([[0]]), tf.constant(
        [[0]])), (tf.constant([[0]]), tf.constant([[0]]))

    # [batch x agenda_dim]
    base_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
        train_dec_out, train_dec_out_len = prepare_decoder_input_output(
            output_words, extended_output_words, output_len)
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words, tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda, embeddings, extended_base_words, oov,
        train_dec_inp, train_dec_inp_extended, base_sent_hidden_states,
        wa_inserted, wa_deleted,
        train_dec_inp_len, base_len, src_len, tgt_len,
        vocab_size, attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted, wa_deleted,
            base_len, src_len, tgt_len,
            vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted, wa_deleted,
            base_len, src_len, tgt_len,
            vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    # Reconstruct the edit vector from the gold output words and penalize the
    # distance to the encoder's edit vector.
    edit_vector_recons = output_words_to_edit_vector(
        output_word_embeds, output_len, edit_dim,
        ctx_hidden_dim, ctx_hidden_layer, recons_dense_layers, swap_memory)
    optimizer.add_reconst_loss(edit_vector, edit_vector_recons)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def attn_encoder(source_words, target_words, insert_words, delete_words,
                 source_lengths, target_lengths, iw_lengths, dw_lengths,
                 transformer_params, wa_hidden_dim,
                 meve_hidden_dim, meve_hidden_layers,
                 edit_dim, micro_edit_ev_dim,
                 noise_scaler, norm_eps, norm_max, dropout_keep=1.,
                 use_dropout=False, swap_memory=False, enable_vae=True):
    """Encode a (source, target) pair into an edit vector via micro-edit
    attention.

    Accumulates insert/delete word evidence, extracts micro edit vectors in
    both directions with a transformer-based extractor, aggregates them with
    a shared RNN encoder, and projects the concatenated features to an edit
    vector (optionally passed through vMF noise when enable_vae is set).

    Args:
        source_words: source token embeddings (per mev_extractor usage).
        target_words: target token embeddings.
        insert_words: inserted-word embeddings.
        delete_words: deleted-word embeddings.
        source_lengths: [batch] true source lengths.
        target_lengths: [batch] true target lengths.
        iw_lengths: [batch] true insert-word lengths.
        dw_lengths: [batch] true delete-word lengths.
        transformer_params: config object with hidden_size / pos_encoding_dim.
        wa_hidden_dim: hidden size of the insert/delete accumulator.
        meve_hidden_dim: hidden size of the micro-edit-vector aggregator RNN.
        meve_hidden_layers: number of aggregator RNN layers.
        edit_dim: dimensionality of the output edit vector.
        micro_edit_ev_dim: dimensionality of each micro edit vector.
        noise_scaler: vMF noise concentration scaler (when enable_vae).
        norm_eps: norm epsilon for vMF sampling.
        norm_max: maximum norm for vMF sampling.
        dropout_keep: dropout keep probability.
        use_dropout: whether to apply dropout to the accumulated features.
        swap_memory: passed to the aggregator RNN.
        enable_vae: whether to inject vMF noise into the edit vector.

    Returns:
        (edit_vector, (cnx_src, micro_evs_st), (cnx_tgt, micro_evs_ts))
    """
    with tf.variable_scope(OPS_NAME):
        # Last-state summary of the inserted and deleted word sets.
        wa_inserted_last, wa_deleted_last = wa_accumulator(insert_words, delete_words,
                                                           iw_lengths, dw_lengths,
                                                           wa_hidden_dim)
        if use_dropout and dropout_keep < 1.:
            wa_inserted_last = tf.nn.dropout(wa_inserted_last, dropout_keep)
            wa_deleted_last = tf.nn.dropout(wa_deleted_last, dropout_keep)

        embedding_matrix = vocab.get_embeddings()
        embedding_layer = ConcatPosEmbedding(transformer_params.hidden_size,
                                             embedding_matrix,
                                             transformer_params.pos_encoding_dim)
        micro_ev_projection = tf.layers.Dense(micro_edit_ev_dim,
                                              activation=None,
                                              use_bias=True,
                                              name='micro_ev_proj')
        mev_extractor = TransformerMicroEditExtractor(embedding_layer,
                                                      micro_ev_projection,
                                                      transformer_params)

        # Micro edit vectors in both directions; the extractor is shared.
        cnx_tgt, micro_evs_st = mev_extractor(source_words, target_words,
                                              source_lengths, target_lengths)
        cnx_src, micro_evs_ts = mev_extractor(target_words, source_words,
                                              target_lengths, source_lengths)

        # Shared-weight RNN aggregator over the micro edit vectors.
        micro_ev_encoder = tf.make_template('micro_ev_encoder', context_encoder,
                                            hidden_dim=meve_hidden_dim,
                                            num_layers=meve_hidden_layers,
                                            swap_memory=swap_memory,
                                            use_dropout=use_dropout,
                                            dropout_keep=dropout_keep)

        aggreg_mev_st = micro_ev_encoder(micro_evs_st, source_lengths)
        aggreg_mev_ts = micro_ev_encoder(micro_evs_ts, target_lengths)

        # Keep only the last relevant (unpadded) hidden state per sequence.
        aggreg_mev_st_last = sequence.last_relevant(aggreg_mev_st, source_lengths)
        aggreg_mev_ts_last = sequence.last_relevant(aggreg_mev_ts, target_lengths)

        if use_dropout and dropout_keep < 1.:
            aggreg_mev_st_last = tf.nn.dropout(aggreg_mev_st_last, dropout_keep)
            aggreg_mev_ts_last = tf.nn.dropout(aggreg_mev_ts_last, dropout_keep)

        # Concatenate both aggregated directions with the insert/delete summaries.
        features = tf.concat([
            aggreg_mev_st_last,
            aggreg_mev_ts_last,
            wa_inserted_last,
            wa_deleted_last
        ], axis=1)

        # Linear projection to the edit vector (no bias).
        edit_vector = tf.layers.dense(features, edit_dim, use_bias=False, name='encoder_ev')

        if enable_vae:
            edit_vector = sample_vMF(edit_vector, noise_scaler, norm_eps, norm_max)

        return edit_vector, (cnx_src, micro_evs_st), (cnx_tgt, micro_evs_ts)
def init_from_embedding_matrix():
    """Create a shared-weights embedding layer from the vocab embedding matrix
    and register it in the 'embed_layer' graph collection."""
    shared_embedding = EmbeddingSharedWeights(vocab.get_embeddings())
    tf.add_to_collection('embed_layer', shared_embedding)
def editor_train(base_words, output_words, source_words, target_words, insert_words,
                 delete_words, hidden_dim, agenda_dim, edit_dim, micro_edit_ev_dim,
                 num_heads, num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim, wa_hidden_layer,
                 meve_hidden_dim, meve_hidden_layers, max_sent_length, dropout_keep,
                 lamb_reg, norm_eps, norm_max, kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False, no_insert_delete_attn=False,
                 enable_vae=True):
    """Build train/inference graphs for the attention-encoder editor variant.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len)
    """
    batch_size = tf.shape(source_words)[0]

    # [batch] true (pre-embedding) sequence lengths.
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    src_word_embeds = vocab.embed_tokens(source_words)
    tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # Shared-weight sentence encoder template (also passed to attn_encoder).
    sent_encoder = tf.make_template('sent_encoder', encoder.source_sent_encoder,
                                    hidden_dim=hidden_dim,
                                    num_layer=num_encoder_layers,
                                    swap_memory=swap_memory,
                                    use_dropout=use_dropout,
                                    dropout_keep=dropout_keep)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = sent_encoder(base_word_embeds, base_len)

    # NOTE(review): assert-based config validation; also guards the branches
    # below — wa_inserted/wa_deleted are only bound in the innermost else, so
    # kill_edit/draw_edit would raise NameError at decoder construction.
    assert kill_edit == False and draw_edit == False

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector, wa_inserted, wa_deleted = attn_encoder(
                src_word_embeds, tgt_word_embeds,
                insert_word_embeds, delete_word_embeds,
                src_len, tgt_len, iw_len, dw_len,
                ctx_hidden_dim, ctx_hidden_layer,
                wa_hidden_dim, wa_hidden_layer,
                meve_hidden_dim, meve_hidden_layers,
                edit_dim, micro_edit_ev_dim, num_heads,
                lamb_reg, norm_eps, norm_max,
                sent_encoder,
                use_dropout=use_dropout, dropout_keep=dropout_keep,
                swap_memory=swap_memory, enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
        train_dec_out, train_dec_out_len = prepare_decoder_input_output(
            output_words, output_len, None)

    train_decoder = decoder.train_decoder(
        input_agenda, embeddings, train_dec_inp, base_sent_hidden_states,
        wa_inserted, wa_deleted,
        train_dec_inp_len, base_len, iw_len, dw_len,
        attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted, wa_deleted,
            base_len, iw_len, dw_len,
            attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            wa_inserted, wa_deleted,
            base_len, iw_len, dw_len,
            attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def editor_train(base_words, extended_base_words, output_words, extended_output_words,
                 source_words, target_words, insert_words, delete_words, oov, vocab_size,
                 hidden_dim, agenda_dim, edit_dim, micro_edit_ev_dim, num_heads,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim, wa_hidden_layer,
                 meve_hidden_dim, meve_hidden_layers, max_sent_length, dropout_keep,
                 lamb_reg, norm_eps, norm_max, kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False, no_insert_delete_attn=False,
                 enable_vae=True):
    """Build train/inference graphs for the no-edit-vector pointer-generator
    variant: the agenda comes from the base sentence alone.

    Returns:
        (train_decoder, infr_decoder, train_dec_out, train_dec_out_len)
    """
    # [batch] true (pre-embedding) sequence lengths.
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len, hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep, swap_memory=swap_memory)

    # NOTE(review): assert-based config validation — edit ablations are not
    # supported by this variant; `assert` is stripped under `python -O`.
    assert kill_edit == False and draw_edit == False

    # [batch x agenda_dim] — agenda from the base sentence only (no edit vector).
    base_agenda = linear(base_sent_embed, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
        train_dec_out, train_dec_out_len = prepare_decoder_input_output(
            output_words, extended_output_words, output_len)
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words, tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda, embeddings, extended_base_words, oov,
        train_dec_inp, train_dec_inp_extended, base_sent_hidden_states,
        train_dec_inp_len, base_len,
        vocab_size, attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len, vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states,
            base_len, vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    # Record attention alignment history for visualization/debugging.
    add_decoder_attn_history_graph(infr_decoder)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def decoder_outputs_to_edit_vector(decoder_output, temperature_starter, decay_rate, decay_steps,
                                   edit_dim, enc_hidden_dim, enc_num_layers, dense_layers,
                                   swap_memory):
    """Reconstruct an edit vector from decoder outputs.

    Relaxes the decoder's softmax outputs into soft word embeddings
    (Gumbel-softmax with an annealed temperature), re-encodes the resulting
    soft sentence, and projects it through a dense stack to an edit vector.
    Used for the edit-vector reconstruction loss.

    Args:
        decoder_output: decoder output structure accepted by decoder.rnn_output()
            and decoder.seq_length().
        temperature_starter: initial relaxation temperature.
        decay_rate: exponential decay rate applied to the temperature.
        decay_steps: number of steps per decay period.
        edit_dim: dimensionality of the reconstructed edit vector.
        enc_hidden_dim: hidden size of the re-encoding RNN.
        enc_num_layers: number of re-encoding RNN layers.
        dense_layers: iterable of hidden-layer widths for the projection stack.
        swap_memory: passed through to the sentence encoder.

    Returns:
        [batch x edit_dim] reconstructed edit vector.
    """
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        embeddings = vocab.get_embeddings()

        # Extend the embedding matrix so OOV ids (50 slots) map onto UNK.
        unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
        unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
        unk_embeddings = tf.tile(unk_embed, [50, 1])

        # [VOCAB+50 x word_dim]
        embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

        # Anneal the sampling temperature as training progresses.
        global_step = tf.train.get_global_step()
        temperature = tf.train.exponential_decay(temperature_starter,
                                                 global_step,
                                                 decay_steps, decay_rate,
                                                 name='temperature')
        tf.summary.scalar('temper', temperature, ['extra'])

        # [batch x max_len x VOCAB+50], softmax probabilities
        outputs = decoder.rnn_output(decoder_output)

        # Replace non-positive probabilities for numerical stability (the
        # relaxed categorical takes log(probs) internally).
        outputs = tf.where(tf.less_equal(outputs, 0), tf.ones_like(outputs) * 1e-10, outputs)

        # Soft one-hot sample of the word distribution.
        dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

        # [batch x max_len x VOCAB+50], approximately one-hot
        outputs_one_hot = dist.sample()

        # [batch x max_len x word_dim] = one_hot^T * embedding_matrix
        outputs_embed = tf.einsum("btv,vd-> btd", outputs_one_hot, embeddings_extended)

        # [batch]
        outputs_length = decoder.seq_length(decoder_output)

        # [batch x max_len x hidden], [batch x hidden]
        hidden_states, sentence_embedding = encoder.source_sent_encoder(
            outputs_embed, outputs_length, enc_hidden_dim, enc_num_layers,
            use_dropout=False, dropout_keep=1.0, swap_memory=swap_memory)

        h = sentence_embedding
        # BUGFIX: name layers by position as well as width. Naming by width
        # alone ('hidden_%s' % width) raises a "variable already exists" error
        # whenever two entries of dense_layers share the same width.
        for i, width in enumerate(dense_layers):
            h = tf.layers.dense(h, width, activation='relu', name='hidden_%s_%s' % (i, width))

        # [batch x edit_dim]
        edit_vector = tf.layers.dense(h, edit_dim, activation=None, name='linear')

    return edit_vector