def train_decoder(agenda, embeddings, extended_base_words, oov, dec_inputs,
                  dec_extended_inputs, base_sent_hiddens, insert_word_embeds,
                  delete_word_embeds, dec_input_lengths, base_length,
                  iw_length, dw_length, vocab_size, attn_dim, hidden_dim,
                  num_layer, swap_memory, enable_dropout=False,
                  dropout_keep=1., no_insert_delete_attn=False):
    """Build the teacher-forced (training) decoder and unroll it.

    Every decoder step receives the embedding of its input token concatenated
    with the token's extended id, cast to float, as one extra feature.

    Args:
        agenda: conditioning input forwarded to ``create_decoder_cell``.
        embeddings: accepted for call-site compatibility; not read here
            (token embeddings come from ``vocab.embed_tokens``).
        extended_base_words, oov, base_sent_hiddens, insert_word_embeds,
        delete_word_embeds, base_length, iw_length, dw_length, vocab_size,
        attn_dim, hidden_dim, num_layer: forwarded to ``create_decoder_cell``.
        dec_inputs: decoder input ids, embedded with ``vocab.embed_tokens``.
        dec_extended_inputs: extended ids of the same inputs; expanded on
            axis 2 and appended as a float feature (assumed rank-2
            [batch x time] — TODO confirm).
        dec_input_lengths: sequence lengths for ``seq2seq.TrainingHelper``.
        swap_memory: forwarded to ``seq2seq.dynamic_decode``.
        enable_dropout, dropout_keep, no_insert_delete_attn: cell options.

    Returns:
        The ``(outputs, state, length)`` triple of ``seq2seq.dynamic_decode``.
    """
    with tf.variable_scope(OPS_NAME, 'decoder'):
        token_embeds = vocab.embed_tokens(dec_inputs)
        # Previous extended id as a trailing float feature on the last axis.
        prev_id_feature = tf.cast(
            tf.expand_dims(dec_extended_inputs, 2), tf.float32)
        step_inputs = tf.concat([token_embeds, prev_id_feature], axis=2)

        helper = seq2seq.TrainingHelper(
            step_inputs, dec_input_lengths, name='train_helper')

        decoder_cell, initial_state = create_decoder_cell(
            agenda, extended_base_words, oov, base_sent_hiddens,
            insert_word_embeds, delete_word_embeds,
            base_length, iw_length, dw_length,
            vocab_size, attn_dim, hidden_dim, num_layer,
            enable_dropout=enable_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

        basic_decoder = seq2seq.BasicDecoder(
            decoder_cell, helper, initial_state)
        final_outputs, final_state, final_lengths = seq2seq.dynamic_decode(
            basic_decoder, swap_memory=swap_memory)
    return final_outputs, final_state, final_lengths
def fn(orig_ids):
    """Map token ids to decoder step inputs (embedding ++ id feature).

    Ids that are not below ``vocab_size`` are replaced by
    ``vocab.OOV_TOKEN_ID`` before the embedding lookup. The raw id is then
    appended as a float feature on axis 2, with the START token's id
    replaced by -1.

    ``vocab_size`` is not a parameter; it is resolved from the enclosing
    scope.
    """
    ids = tf.cast(orig_ids, tf.int64)

    is_known = tf.less(ids, vocab_size)
    oov_fill = tf.ones_like(ids) * vocab.OOV_TOKEN_ID
    token_embeds = vocab.embed_tokens(tf.where(is_known, ids, oov_fill))

    # The START marker is fed as -1 in the id feature.
    is_start = tf.equal(ids, vocab.get_token_id(vocab.START_TOKEN))
    minus_one = tf.ones_like(ids) * -1
    id_feature = tf.cast(
        tf.expand_dims(tf.where(is_start, minus_one, ids), 2), tf.float32)

    return tf.concat([token_embeds, id_feature], axis=2)
def prepare_output_embed(
        decoder_output,
        temperature_starter,
        decay_rate,
        decay_steps,
        num_oov_slots=50,
):
    """Turn decoder output probabilities into relaxed ("soft") word embeddings.

    A ``tfd.RelaxedOneHotCategorical`` is built over the decoder's per-step
    probabilities and sampled. The relaxed one-hot samples are multiplied
    with an embedding matrix extended by ``num_oov_slots`` extra rows, each a
    copy of the UNKNOWN-token embedding, so extended (OOV) ids also map to an
    embedding.

    Args:
        decoder_output: decoder result; probabilities are read with
            ``decoder.rnn_output``.
        temperature_starter: initial relaxation temperature.
        decay_rate: decay factor for ``tf.train.exponential_decay``.
        decay_steps: decay period (in global steps) for the temperature.
        num_oov_slots: number of extra embedding rows appended for extended
            ids. The einsum below requires the decoder's output width to
            equal ``VOCAB + num_oov_slots``. Defaults to 50, the value that
            was previously hard-coded.

    Returns:
        The soft embeddings; per the shape comments, [batch x max_len x
        word_dim].
    """
    # [VOCAB x word_dim]
    embeddings = vocab.get_embeddings()

    # Extend the embedding matrix with one UNKNOWN-embedding row per OOV slot.
    unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
    unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
    unk_embeddings = tf.tile(unk_embed, [num_oov_slots, 1])
    # [VOCAB+num_oov_slots x word_dim]
    embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

    # Temperature decays exponentially with the global step; logged to the
    # 'extra' summary collection. Requires a global step to exist.
    global_step = tf.train.get_global_step()
    temperature = tf.train.exponential_decay(temperature_starter, global_step,
                                             decay_steps, decay_rate,
                                             name='temperature')
    tf.summary.scalar('temper', temperature, ['extra'])

    # [batch x max_len x VOCAB+num_oov_slots], softmax probabilities
    outputs = decoder.rnn_output(decoder_output)
    # Replace non-positive entries (<= 0, i.e. including exact zeros) with
    # 1e-10 for numerical stability: a zero probability would become an
    # infinite logit inside the distribution.
    outputs = tf.where(tf.less_equal(outputs, 0),
                       tf.ones_like(outputs) * 1e-10, outputs)

    # Relaxed (continuous) one-hot samples from the probabilities.
    dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)
    # [batch x max_len x VOCAB+num_oov_slots]
    outputs_one_hot = dist.sample()

    # [batch x max_len x word_dim], one_hot^T * embedding_matrix
    outputs_embed = tf.einsum("btv,vd-> btd", outputs_one_hot,
                              embeddings_extended)
    return outputs_embed
def fn(orig_ids):
    """Embed token ids, first mapping any id >= ``vocab_size`` to the OOV id.

    ``vocab_size`` is not a parameter; it is resolved from the enclosing
    scope.
    """
    ids = tf.cast(orig_ids, tf.int64)
    within_vocab = tf.less(ids, vocab_size)
    oov_ids = tf.ones_like(ids) * vocab.OOV_TOKEN_ID
    clipped_ids = tf.where(within_vocab, ids, oov_ids)
    return vocab.embed_tokens(clipped_ids)
def editor_train(base_words, source_words, target_words, insert_words,
                 delete_words, embed_matrix, vocab_table, hidden_dim,
                 agenda_dim, edit_dim, num_encoder_layers, num_decoder_layers,
                 attn_dim, beam_width, max_sent_length, dropout_keep, lamb_reg,
                 norm_eps, norm_max, kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    """Build the editor graph: base encoder, edit vector, train/infer decoders.

    Steps:
      1. Encode ``base_words`` with ``encoder.source_sent_encoder``.
      2. Build the edit vector. It is zeros if ``kill_edit``, else
         ``edit_encoder.random_noise_encoder`` if ``draw_edit``, else
         ``edit_encoder.accumulator_encoder`` over the insert/delete word
         embeddings.
      3. Combine the base sentence embedding and the edit vector with
         ``agn.linear`` into the decoder agenda.
      4. Build a teacher-forced training decoder on ``target_words`` and an
         inference decoder: beam search if ``use_beam_decoder``, else greedy.
         START/STOP ids are looked up through ``vocab_table``.

    Returns:
        ``(train_decoder, infr_decoder, train_dec_out, train_dec_out_len)``.
        The last two come from ``prepare_decoder_input_output``.

    NOTE(review): ``embed_matrix`` is never read in this body, and
    ``src_len`` is computed but unused. Both are left as-is.
    """
    # Only used by the kill_edit / draw_edit branches below.
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len, hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # [batch x edit_dim]
    if kill_edit:
        # Ablation: no edit information at all.
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            # Edit vector drawn from noise instead of encoded from the edit.
            edit_vector = edit_encoder.random_noise_encoder(
                batch_size, edit_dim, norm_max)
        else:
            # Edit vector encoded from the inserted / deleted words.
            edit_vector = edit_encoder.accumulator_encoder(
                insert_word_embeds, delete_word_embeds, iw_len, dw_len,
                edit_dim, lamb_reg, norm_eps, norm_max, dropout_keep,
                enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(target_words, tgt_len, vocab_table)

    train_decoder = decoder.train_decoder(
        input_agenda, embeddings,
        train_dec_inp, base_sent_hidden_states, insert_word_embeds,
        delete_word_embeds, train_dec_inp_len, base_len, iw_len, dw_len,
        attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states, insert_word_embeds, delete_word_embeds,
            base_len, iw_len, dw_len, attn_dim, hidden_dim,
            num_decoder_layers, max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states, insert_word_embeds, delete_word_embeds,
            base_len, iw_len, dw_len, attn_dim, hidden_dim,
            num_decoder_layers, max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def editor_train(base_words, extended_base_words, output_words,
                 extended_output_words, source_words, target_words,
                 insert_words, delete_words, oov, vocab_size, hidden_dim,
                 agenda_dim, edit_dim, micro_edit_ev_dim, num_heads,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim,
                 wa_hidden_layer, meve_hidden_dim, meve_hidden_layers,
                 recons_dense_layers, max_sent_length, dropout_keep, lamb_reg,
                 norm_eps, norm_max, kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    """Build the copy-enabled editor graph plus an edit-vector reconstruction loss.

    Steps:
      1. Encode ``base_words`` with ``encoder.source_sent_encoder``.
      2. Build the edit vector with ``accumulator_encoder`` over the
         insert/delete word embeddings. Only this path is permitted, by the
         assert below.
      3. Combine the base sentence embedding and the edit vector with
         ``agn.linear`` into the agenda.
      4. Build the training decoder and a beam or greedy inference decoder.
         Both receive ``extended_base_words``, ``oov`` and ``vocab_size``
         (extended-vocabulary / copy inputs).
      5. Re-derive an edit vector from ``output_words`` via
         ``output_words_to_edit_vector``. Hand it and the original edit
         vector to ``optimizer.add_reconst_loss``, presumably to register a
         reconstruction loss — confirm in ``optimizer``.

    Returns:
        ``(train_decoder, infr_decoder, train_dec_out, train_dec_out_len)``.

    NOTE(review): ``src_len`` / ``tgt_len`` are passed to the decoders in the
    positions that follow ``base_len``. In ``train_decoder``'s signature in
    this file, those slots are ``iw_length`` / ``dw_length``. Confirm this
    pairing with the constant ``wa_inserted`` / ``wa_deleted`` stand-ins is
    intended.
    """
    # Only used by the kill_edit / draw_edit branches below.
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    output_word_embeds = vocab.embed_tokens(output_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len, hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # Only the accumulator-encoder path is supported.
    # NOTE(review): `assert` is stripped under `python -O`; then the
    # kill_edit / draw_edit branches below become reachable.
    assert kill_edit == False and draw_edit == False

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector = accumulator_encoder(insert_word_embeds, delete_word_embeds, iw_len, dw_len, edit_dim, lamb_reg, norm_eps,
                                              norm_max, dropout_keep, enable_vae=enable_vae)

    # Stand-ins for the inserted/deleted word-attention inputs of the
    # decoders: pairs of constant [[0]] tensors rather than computed values.
    wa_inserted, wa_deleted = (tf.constant([[0]]), tf.constant(
        [[0]])), (tf.constant([[0]]), tf.constant([[0]]))

    # [batch x agenda_dim]
    base_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    # Extended-id decoder inputs. The -1 is presumably the start marker for
    # the extended-id stream — confirm against prepare_decoder_inputs.
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words, tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda, embeddings, extended_base_words, oov, train_dec_inp,
        train_dec_inp_extended, base_sent_hidden_states, wa_inserted,
        wa_deleted, train_dec_inp_len, base_len, src_len, tgt_len,
        vocab_size, attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, wa_inserted, wa_deleted,
            base_len, src_len, tgt_len, vocab_size, attn_dim, hidden_dim,
            num_decoder_layers, max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, wa_inserted, wa_deleted,
            base_len, src_len, tgt_len, vocab_size, attn_dim, hidden_dim,
            num_decoder_layers, max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    # Edit vector re-derived from the output sentence; paired with the
    # encoder's edit vector for the reconstruction loss.
    edit_vector_recons = output_words_to_edit_vector(
        output_word_embeds, output_len, edit_dim, ctx_hidden_dim,
        ctx_hidden_layer, recons_dense_layers, swap_memory)
    optimizer.add_reconst_loss(edit_vector,
                               edit_vector_recons)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def _get_input_embeddings(self, ids):
    """Embed ``ids`` after replacing any id not below ``self.vocab_size``
    with ``vocab.OOV_TOKEN_ID``."""
    in_range = tf.less(ids, self.vocab_size)
    oov_ids = tf.ones_like(ids) * vocab.OOV_TOKEN_ID
    return vocab.embed_tokens(tf.where(in_range, ids, oov_ids))
def editor_train(base_words, output_words, source_words, target_words,
                 insert_words, delete_words, hidden_dim, agenda_dim, edit_dim,
                 micro_edit_ev_dim, num_heads, num_encoder_layers,
                 num_decoder_layers, attn_dim, beam_width, ctx_hidden_dim,
                 ctx_hidden_layer, wa_hidden_dim, wa_hidden_layer,
                 meve_hidden_dim, meve_hidden_layers, max_sent_length,
                 dropout_keep, lamb_reg, norm_eps, norm_max, kill_edit,
                 draw_edit, swap_memory, use_beam_decoder=False,
                 use_dropout=False, no_insert_delete_attn=False,
                 enable_vae=True):
    """Build the editor graph with the attention-based edit encoder.

    Steps:
      1. Wrap ``encoder.source_sent_encoder`` in a ``tf.make_template`` so
         the base-sentence encoder and ``attn_encoder`` share its variables.
      2. Encode ``base_words``.
      3. Compute the edit vector together with the inserted/deleted
         word-attention outputs via ``attn_encoder``.
      4. Combine the base sentence embedding and the edit vector with
         ``agn.linear`` into the agenda.
      5. Build a teacher-forced training decoder on ``output_words`` and a
         beam (``use_beam_decoder``) or greedy inference decoder.

    Returns:
        ``(train_decoder, infr_decoder, train_dec_out, train_dec_out_len)``.

    Raises:
        ValueError: if ``kill_edit`` or ``draw_edit`` is set. Neither mode is
            supported here, because only ``attn_encoder`` produces the
            ``wa_inserted`` / ``wa_deleted`` tensors the decoders require.
    """
    # Explicit check instead of `assert`, so the guard survives `python -O`.
    # The previous assert-guarded kill_edit / draw_edit branches never bound
    # wa_inserted / wa_deleted, so under -O they crashed later with an
    # UnboundLocalError. Those branches are removed.
    if kill_edit or draw_edit:
        raise ValueError(
            'editor_train does not support kill_edit or draw_edit '
            '(got kill_edit=%r, draw_edit=%r)' % (kill_edit, draw_edit))

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    src_word_embeds = vocab.embed_tokens(source_words)
    tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # Template so every call of the sentence encoder reuses one variable set.
    sent_encoder = tf.make_template('sent_encoder',
                                    encoder.source_sent_encoder,
                                    hidden_dim=hidden_dim,
                                    num_layer=num_encoder_layers,
                                    swap_memory=swap_memory,
                                    use_dropout=use_dropout,
                                    dropout_keep=dropout_keep)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = sent_encoder(
        base_word_embeds, base_len)

    # [batch x edit_dim], plus word-attention outputs for the decoders.
    edit_vector, wa_inserted, wa_deleted = attn_encoder(
        src_word_embeds, tgt_word_embeds, insert_word_embeds,
        delete_word_embeds, src_len, tgt_len, iw_len, dw_len,
        ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim, wa_hidden_layer,
        meve_hidden_dim, meve_hidden_layers, edit_dim, micro_edit_ev_dim,
        num_heads, lamb_reg, norm_eps, norm_max, sent_encoder,
        use_dropout=use_dropout, dropout_keep=dropout_keep,
        swap_memory=swap_memory, enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    (train_dec_inp, train_dec_inp_len,
     train_dec_out, train_dec_out_len) = prepare_decoder_input_output(
        output_words, output_len, None)

    train_decoder = decoder.train_decoder(
        input_agenda, embeddings, train_dec_inp, base_sent_hidden_states,
        wa_inserted, wa_deleted, train_dec_inp_len, base_len, iw_len, dw_len,
        attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, wa_inserted, wa_deleted,
            base_len, iw_len, dw_len, attn_dim, hidden_dim,
            num_decoder_layers, max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, wa_inserted, wa_deleted,
            base_len, iw_len, dw_len, attn_dim, hidden_dim,
            num_decoder_layers, max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def editor_train(base_words, extended_base_words, output_words,
                 extended_output_words, source_words, target_words,
                 insert_words, delete_words, oov, vocab_size, hidden_dim,
                 agenda_dim, edit_dim, micro_edit_ev_dim, num_heads,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim,
                 wa_hidden_layer, meve_hidden_dim, meve_hidden_layers,
                 max_sent_length, dropout_keep, lamb_reg, norm_eps, norm_max,
                 kill_edit, draw_edit, swap_memory, use_beam_decoder=False,
                 use_dropout=False, no_insert_delete_attn=False,
                 enable_vae=True):
    """Build a copy-enabled decoder graph conditioned only on the base sentence.

    No edit vector is computed. The agenda is ``linear(base_sent_embed,
    agenda_dim)``. The training decoder and the beam/greedy inference
    decoder receive ``extended_base_words``, ``oov`` and ``vocab_size``
    (extended-vocabulary / copy inputs). Afterwards
    ``add_decoder_attn_history_graph`` is called on the inference decoder.

    Returns:
        ``(train_decoder, infr_decoder, train_dec_out, train_dec_out_len)``.

    NOTE(review): many parameters are not read in this body, e.g.
    ``source_words``, ``insert_words``, ``edit_dim``, the ``ctx_*`` /
    ``wa_*`` / ``meve_*`` sizes, ``lamb_reg`` and ``enable_vae``. This is
    presumably for call-site compatibility with other ``editor_train``
    variants. ``kill_edit`` / ``draw_edit`` are only checked by an
    ``assert``, which ``python -O`` removes.
    """
    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len, hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # Edit-vector modes do not apply to this variant.
    assert kill_edit == False and draw_edit == False

    # [batch x agenda_dim]
    base_agenda = linear(base_sent_embed, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)

    # Extended-id decoder inputs. The -1 is presumably the start marker for
    # the extended-id stream — confirm against prepare_decoder_inputs.
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words, tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda, embeddings, extended_base_words, oov, train_dec_inp,
        train_dec_inp_extended, base_sent_hidden_states, train_dec_inp_len,
        base_len, vocab_size, attn_dim, hidden_dim, num_decoder_layers,
        swap_memory, enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, base_len, vocab_size, attn_dim,
            hidden_dim, num_decoder_layers, max_sent_length, beam_width,
            swap_memory, enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, base_len, vocab_size, attn_dim,
            hidden_dim, num_decoder_layers, max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    # Attach attention-history ops for the inference decoder.
    add_decoder_attn_history_graph(infr_decoder)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def decoder_outputs_to_edit_vector(decoder_output, temperature_starter,
                                   decay_rate, decay_steps, edit_dim,
                                   enc_hidden_dim, enc_num_layers,
                                   dense_layers, swap_memory):
    """Re-encode a decoder's output distribution into an edit vector.

    Everything is built under ``tf.variable_scope(OPS_NAME)``:

    1. Sample relaxed one-hot vectors (``tfd.RelaxedOneHotCategorical``)
       from the decoder's output probabilities. The temperature decays
       exponentially with the global step and is logged as the 'temper'
       summary.
    2. Project the samples onto the vocabulary embedding matrix. The matrix
       is extended with 50 copies of the UNKNOWN-token embedding, one per
       extended (OOV) id.
    3. Run the soft embeddings through ``encoder.source_sent_encoder`` with
       dropout disabled. Feed the sentence embedding through ReLU dense
       layers of the widths in ``dense_layers`` and a final linear layer of
       width ``edit_dim``.

    Returns:
        The edit vector ([batch x edit_dim] per the original shape comment).
    """
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        vocab_embeds = vocab.get_embeddings()
        unk_row = tf.expand_dims(
            vocab.embed_tokens(vocab.get_token_id(vocab.UNKNOWN_TOKEN)), 0)
        # [VOCAB+50 x word_dim]: 50 extra rows, each the UNKNOWN embedding.
        extended_embeds = tf.concat(
            [vocab_embeds, tf.tile(unk_row, [50, 1])], axis=0)

        decayed_temp = tf.train.exponential_decay(
            temperature_starter, tf.train.get_global_step(), decay_steps,
            decay_rate, name='temperature')
        tf.summary.scalar('temper', decayed_temp, ['extra'])

        # [batch x max_len x VOCAB+50] output probabilities.
        probs = decoder.rnn_output(decoder_output)
        # Clamp entries <= 0 (including exact zeros) to 1e-10 for numerical
        # stability before building the distribution.
        probs = tf.where(tf.less_equal(probs, 0),
                         tf.ones_like(probs) * 1e-10, probs)

        # [batch x max_len x VOCAB+50] relaxed one-hot samples.
        soft_one_hot = tfd.RelaxedOneHotCategorical(
            decayed_temp, probs=probs).sample()
        # [batch x max_len x word_dim]
        soft_embeds = tf.einsum("btv,vd-> btd", soft_one_hot, extended_embeds)

        # [batch]
        seq_lengths = decoder.seq_length(decoder_output)
        _, sent_vec = encoder.source_sent_encoder(
            soft_embeds, seq_lengths, enc_hidden_dim, enc_num_layers,
            use_dropout=False, dropout_keep=1.0, swap_memory=swap_memory)

        hidden = sent_vec
        for width in dense_layers:
            hidden = tf.layers.dense(hidden, width, activation='relu',
                                     name='hidden_%s' % width)
        # [batch x edit_dim]
        return tf.layers.dense(hidden, edit_dim, activation=None,
                               name='linear')