def output_words_to_edit_vector(output_words_embed, output_words_len, edit_dim,
                                enc_hidden_dim, enc_num_layers, dense_layers,
                                swap_memory):
    """Map an embedded output sentence to a (reconstructed) edit vector.

    The sentence is encoded with ``encoder.source_sent_encoder`` (dropout
    disabled). Its sentence-level embedding is passed through one ReLU dense
    layer per entry of ``dense_layers``, then through a final linear layer
    with ``edit_dim`` units. All variables live under ``OPS_NAME``.

    Args:
        output_words_embed: embedded output tokens, in the format
            ``encoder.source_sent_encoder`` expects (presumably
            [batch x max_len x embed_dim] -- TODO confirm).
        output_words_len: sequence lengths for ``output_words_embed``.
        edit_dim: number of units of the final linear layer.
        enc_hidden_dim: hidden size of the sentence encoder.
        enc_num_layers: number of layers of the sentence encoder.
        dense_layers: iterable of widths, one ReLU dense layer per entry.
        swap_memory: forwarded to the sentence encoder.

    Returns:
        Output of the final linear layer (presumably [batch x edit_dim]).
    """
    with tf.variable_scope(OPS_NAME):
        # Only the sentence-level embedding is used; per-step states are dropped.
        _, seq_embedding = encoder.source_sent_encoder(
            output_words_embed, output_words_len, enc_hidden_dim,
            enc_num_layers, use_dropout=False, dropout_keep=1.0,
            swap_memory=swap_memory)

        h = seq_embedding
        used_names = set()
        for i, units in enumerate(dense_layers):
            # Layers are named after their width. A repeated width would reuse
            # the same variable scope and TF1 would fail with "Variable ...
            # already exists". Repeats therefore get the layer index appended.
            # The first occurrence keeps the original name so previously saved
            # checkpoints still restore.
            name = 'hidden_%s' % (units,)
            if name in used_names:
                name = 'hidden_%s_%d' % (units, i)
            used_names.add(name)
            h = tf.layers.dense(h, units, activation='relu', name=name)

        edit_vector = tf.layers.dense(h, edit_dim, activation=None,
                                      name='linear')

    return edit_vector
def editor_train(base_words, source_words, target_words, insert_words,
                 delete_words, embed_matrix, vocab_table, hidden_dim,
                 agenda_dim, edit_dim, num_encoder_layers, num_decoder_layers,
                 attn_dim, beam_width, max_sent_length, dropout_keep, lamb_reg,
                 norm_eps, norm_max, kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    """Build the training and inference decoders of the neural editor.

    Pipeline:
      1. Encode the base sentence with ``encoder.source_sent_encoder``.
      2. Build an edit vector, chosen by the flags:
         - all zeros when ``kill_edit``;
         - ``edit_encoder.random_noise_encoder`` when ``draw_edit``;
         - otherwise ``edit_encoder.accumulator_encoder`` over the inserted
           and deleted words.
      3. Combine the sentence embedding and the edit vector into an agenda
         with ``agn.linear``.
      4. Build a training decoder on ``target_words``.
      5. Build an inference decoder: beam search when ``use_beam_decoder``,
         greedy otherwise.

    Both decoders receive the base-sentence hidden states and the
    insert/delete word embeddings together with their lengths.

    ``embed_matrix`` is accepted but not used; embeddings come from
    ``vocab.get_embeddings()``.

    Returns:
        ``(train_decoder, infr_decoder, train_dec_out, train_dec_out_len)``.
    """
    batch_size = tf.shape(source_words)[0]

    # Per-sequence lengths; src_len is computed but not consumed below.
    base_len, src_len, tgt_len, iw_len, dw_len = (
        seq.length_pre_embedding(words)
        for words in (base_words, source_words, target_words,
                      insert_words, delete_words))

    # Shared embedding variable, then the embedded token sequences.
    embeddings = vocab.get_embeddings()
    base_word_embeds, insert_word_embeds, delete_word_embeds = (
        vocab.embed_tokens(words)
        for words in (base_words, insert_words, delete_words))

    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len, hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # Edit vector: ablated (zeros), randomly drawn, or encoded from the edit.
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    elif draw_edit:
        edit_vector = edit_encoder.random_noise_encoder(
            batch_size, edit_dim, norm_max)
    else:
        edit_vector = edit_encoder.accumulator_encoder(
            insert_word_embeds, delete_word_embeds, iw_len, dw_len, edit_dim,
            lamb_reg, norm_eps, norm_max, dropout_keep, enable_vae=enable_vae)

    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    (train_dec_inp, train_dec_inp_len,
     train_dec_out, train_dec_out_len) = prepare_decoder_input_output(
        target_words, tgt_len, vocab_table)

    # Argument groups shared by every decoder built below.
    attn_memory = (base_sent_hidden_states, insert_word_embeds,
                   delete_word_embeds)
    attn_memory_len = (base_len, iw_len, dw_len)
    decoder_dims = (attn_dim, hidden_dim, num_decoder_layers)
    decoder_opts = dict(enable_dropout=use_dropout, dropout_keep=dropout_keep,
                        no_insert_delete_attn=no_insert_delete_attn)

    train_decoder = decoder.train_decoder(
        input_agenda, embeddings, train_dec_inp, *attn_memory,
        train_dec_inp_len, *attn_memory_len, *decoder_dims, swap_memory,
        **decoder_opts)

    start_id = vocab.get_token_id(vocab.START_TOKEN, vocab_table)
    stop_id = vocab.get_token_id(vocab.STOP_TOKEN, vocab_table)
    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda, embeddings, start_id, stop_id, *attn_memory,
            *attn_memory_len, *decoder_dims, max_sent_length, beam_width,
            swap_memory, **decoder_opts)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda, embeddings, start_id, stop_id, *attn_memory,
            *attn_memory_len, *decoder_dims, max_sent_length, swap_memory,
            **decoder_opts)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def editor_train(base_words, extended_base_words, output_words,
                 extended_output_words, source_words, target_words,
                 insert_words, delete_words, oov, vocab_size, hidden_dim,
                 agenda_dim, edit_dim, micro_edit_ev_dim, num_heads,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim,
                 wa_hidden_layer, meve_hidden_dim, meve_hidden_layers,
                 recons_dense_layers, max_sent_length, dropout_keep, lamb_reg,
                 norm_eps, norm_max, kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    """Build editor decoders plus an edit-vector reconstruction loss.

    Pipeline:
      1. Encode the base sentence.
      2. Encode the inserted/deleted words into an edit vector with
         ``accumulator_encoder``.
      3. Combine the sentence embedding and the edit vector into an agenda
         via ``agn.linear``.
      4. Build a training decoder and an inference decoder (beam search when
         ``use_beam_decoder``, greedy otherwise) over the base sentence,
         using the extended (OOV-aware) word ids.
      5. Re-derive an edit vector from the output sentence with
         ``output_words_to_edit_vector`` and hand both vectors to
         ``optimizer.add_reconst_loss``.

    ``micro_edit_ev_dim``, ``num_heads``, ``wa_hidden_dim``,
    ``wa_hidden_layer``, ``meve_hidden_dim`` and ``meve_hidden_layers`` are
    accepted but unused by this variant.

    Returns:
        ``(train_decoder, infr_decoder, train_dec_out, train_dec_out_len)``.

    Raises:
        ValueError: if ``kill_edit`` or ``draw_edit`` is set. This variant
            always encodes the edit with ``accumulator_encoder``. (Previously
            an ``assert``, which is stripped under ``python -O``.)
    """
    if kill_edit or draw_edit:
        raise ValueError(
            'kill_edit and draw_edit are not supported by this editor_train '
            'variant (got kill_edit=%r, draw_edit=%r)' % (kill_edit, draw_edit))

    # Sequence lengths, [batch] each.
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # Embedding variable of shape [vocab_size, embed_dim].
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    output_word_embeds = vocab.embed_tokens(output_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len, hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # [batch x edit_dim]
    edit_vector = accumulator_encoder(
        insert_word_embeds, delete_word_embeds, iw_len, dw_len, edit_dim,
        lamb_reg, norm_eps, norm_max, dropout_keep, enable_vae=enable_vae)

    # Dummy (constant [[0]]) values passed to the decoders in the
    # wa_inserted / wa_deleted slots. This variant does not compute them.
    wa_inserted = (tf.constant([[0]]), tf.constant([[0]]))
    wa_deleted = (tf.constant([[0]]), tf.constant([[0]]))

    # [batch x agenda_dim]
    base_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    (train_dec_inp, train_dec_inp_len,
     train_dec_out, train_dec_out_len) = prepare_decoder_input_output(
        output_words, extended_output_words, output_len)
    # NOTE(review): the second argument appears to be the start-token id
    # (compare test_decoder_train); -1 is used for the extended ids -- confirm.
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda, embeddings, extended_base_words, oov, train_dec_inp,
        train_dec_inp_extended, base_sent_hidden_states, wa_inserted,
        wa_deleted, train_dec_inp_len, base_len, src_len, tgt_len, vocab_size,
        attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN), base_sent_hidden_states,
            wa_inserted, wa_deleted, base_len, src_len, tgt_len, vocab_size,
            attn_dim, hidden_dim, num_decoder_layers, max_sent_length,
            beam_width, swap_memory, enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN), base_sent_hidden_states,
            wa_inserted, wa_deleted, base_len, src_len, tgt_len, vocab_size,
            attn_dim, hidden_dim, num_decoder_layers, max_sent_length,
            swap_memory, enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    # Auxiliary objective: the edit vector recovered from the output sentence
    # is passed with the original to optimizer.add_reconst_loss.
    edit_vector_recons = output_words_to_edit_vector(
        output_word_embeds, output_len, edit_dim, ctx_hidden_dim,
        ctx_hidden_layer, recons_dense_layers, swap_memory)
    optimizer.add_reconst_loss(edit_vector, edit_vector_recons)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def test_decoder_train(dataset_file, embedding_file):
    """Smoke test: build train and beam-search decoders, then run one batch.

    ``dataset_file`` and ``embedding_file`` are each unpacked into a
    (path, gold) pair, and only the paths are used. The graph is built in a
    fresh ``tf.Graph``. Global and local variables and lookup tables are
    initialized, and one batch of beam-search decoder output is printed.
    Nothing is asserted; the test only checks that building and running the
    graph does not raise.
    """
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        stop_token = tf.constant(bytes(vocab.STOP_TOKEN, encoding='utf8'),
                                 dtype=tf.string)
        stop_token_id = vocab_lookup.lookup(stop_token)

        start_token = tf.constant(bytes(vocab.START_TOKEN, encoding='utf8'),
                                  dtype=tf.string)
        start_token_id = vocab_lookup.lookup(start_token)

        pad_token = tf.constant(bytes(vocab.PAD_TOKEN, encoding='utf8'),
                                dtype=tf.string)
        pad_token_id = vocab_lookup.lookup(pad_token)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE,
                                         NUM_EPOCH)
        # Named data_iter rather than iter to avoid shadowing the builtin.
        data_iter = dataset.make_initializable_iterator()
        (src, tgt, inw, dlw), _ = data_iter.get_next()

        src_len = sequence.length_pre_embedding(src)
        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len,
                                                     stop_token_id,
                                                     pad_token_id)

        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        batch_size = tf.shape(src)[0]
        edit_vector = edit_encoder.random_noise_encoder(batch_size, EDIT_DIM,
                                                        14.0)

        embedding = tf.get_variable(
            'embeddings', shape=embed_matrix.shape,
            initializer=tf.constant_initializer(embed_matrix))

        src_embd = tf.nn.embedding_lookup(embedding, src)
        # NOTE(review): elsewhere source_sent_encoder is called with keyword
        # use_dropout=/dropout_keep=. Here 0.8 is the 5th positional argument,
        # which may bind to use_dropout rather than dropout_keep -- confirm.
        src_sent_embeds, final_states = encoder.source_sent_encoder(
            src_embd, src_len, 20, 3, 0.8)

        input_agenda = agenda.linear(final_states, edit_vector, 4)

        # Insert/delete embeddings and lengths shared by both decoders.
        inw_embeds = tf.nn.embedding_lookup(embedding, inw)
        dlw_embeds = tf.nn.embedding_lookup(embedding, dlw)
        inw_len = sequence.length_pre_embedding(inw)
        dlw_len = sequence.length_pre_embedding(dlw)

        dec_out = decoder.train_decoder(
            input_agenda, embedding, dec_inputs, src_sent_embeds, inw_embeds,
            dlw_embeds, dec_inputs_len, src_len, inw_len, dlw_len, 5, 20, 3,
            False)

        # decoder.greedy_eval_decoder (same arguments) can be swapped in here.
        # NOTE(review): unlike editor_train, no beam_width argument is passed;
        # confirm this matches the decoder API this test targets.
        eval_dec_out = decoder.beam_eval_decoder(
            input_agenda, embedding, start_token_id, stop_token_id,
            src_sent_embeds, inw_embeds, dlw_embeds, src_len, inw_len,
            dlw_len, 5, 20, 3, 40)

        # Unpacking checks that the training decoder output is a 3-tuple. The
        # attention-score graph is built only to check that it can be; neither
        # result is run.
        _dec_rnn_out, _dec_final_states, _dec_lengths = dec_out
        _attn_scores = decoder.attention_score(dec_out)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(data_iter.initializer)

            print(sess.run([eval_dec_out]))
def editor_train(base_words, extended_base_words, output_words,
                 extended_output_words, source_words, target_words,
                 insert_words, delete_words, oov, vocab_size, hidden_dim,
                 agenda_dim, edit_dim, micro_edit_ev_dim, num_heads,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim,
                 wa_hidden_layer, meve_hidden_dim, meve_hidden_layers,
                 max_sent_length, dropout_keep, lamb_reg, norm_eps, norm_max,
                 kill_edit, draw_edit, swap_memory, use_beam_decoder=False,
                 use_dropout=False, no_insert_delete_attn=False,
                 enable_vae=True):
    """Build decoders conditioned on the base sentence only (no edit vector).

    Pipeline:
      1. Encode the base sentence.
      2. Project its sentence embedding to an agenda with
         ``linear(base_sent_embed, agenda_dim)``.
      3. Build a training decoder and an inference decoder (beam search when
         ``use_beam_decoder``, greedy otherwise) using the extended
         (OOV-aware) word ids.
      4. Attach ``add_decoder_attn_history_graph`` to the inference decoder.

    Many parameters are accepted for signature compatibility but unused here:
    ``source_words``, ``target_words``, ``insert_words``, ``delete_words``,
    ``edit_dim``, ``micro_edit_ev_dim``, ``num_heads``, ``ctx_hidden_*``,
    ``wa_hidden_*``, ``meve_hidden_*``, ``lamb_reg``, ``norm_eps``,
    ``norm_max`` and ``enable_vae``.

    Returns:
        ``(train_decoder, infr_decoder, train_dec_out, train_dec_out_len)``.

    Raises:
        ValueError: if ``kill_edit`` or ``draw_edit`` is set; this variant has
            no edit vector to ablate or sample. (Previously an ``assert``,
            which is stripped under ``python -O``.)
    """
    if kill_edit or draw_edit:
        raise ValueError(
            'kill_edit and draw_edit are not supported by this editor_train '
            'variant (got kill_edit=%r, draw_edit=%r)' % (kill_edit, draw_edit))

    # Sequence lengths, [batch] each.
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)

    # Embedding variable of shape [vocab_size, embed_dim].
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len, hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep,
        swap_memory=swap_memory)

    # [batch x agenda_dim]
    base_agenda = linear(base_sent_embed, agenda_dim)

    (train_dec_inp, train_dec_inp_len,
     train_dec_out, train_dec_out_len) = prepare_decoder_input_output(
        output_words, extended_output_words, output_len)
    # NOTE(review): the second argument appears to be the start-token id
    # (compare test_decoder_train); -1 is used for the extended ids -- confirm.
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words,
                                                    tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda, embeddings, extended_base_words, oov, train_dec_inp,
        train_dec_inp_extended, base_sent_hidden_states, train_dec_inp_len,
        base_len, vocab_size, attn_dim, hidden_dim, num_decoder_layers,
        swap_memory, enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN), base_sent_hidden_states,
            base_len, vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN), base_sent_hidden_states,
            base_len, vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, swap_memory, enable_dropout=use_dropout,
            dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    add_decoder_attn_history_graph(infr_decoder)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def decoder_outputs_to_edit_vector(decoder_output, temperature_starter,
                                   decay_rate, decay_steps, edit_dim,
                                   enc_hidden_dim, enc_num_layers,
                                   dense_layers, swap_memory,
                                   num_oov_slots=50):
    """Re-encode decoder output distributions into an edit vector.

    Steps:
      1. Draw a relaxed one-hot sample (``tfd.RelaxedOneHotCategorical``)
         from the decoder's softmax outputs. The sample temperature decays
         exponentially with the global step.
      2. Embed the sample with the vocabulary embedding matrix, which is
         extended by ``num_oov_slots`` copies of the UNK embedding.
      3. Encode the embedded sequence with ``encoder.source_sent_encoder``.
      4. Map the sentence embedding through ReLU dense layers and a final
         linear layer with ``edit_dim`` units.

    All variables live under ``OPS_NAME``.

    Args:
        decoder_output: decoder result accepted by ``decoder.rnn_output`` and
            ``decoder.seq_length``.
        temperature_starter: initial relaxation temperature.
        decay_rate, decay_steps: ``tf.train.exponential_decay`` parameters
            for the temperature.
        edit_dim: number of units of the final linear layer.
        enc_hidden_dim, enc_num_layers: sentence-encoder size.
        dense_layers: iterable of widths, one ReLU dense layer per entry.
        swap_memory: forwarded to the sentence encoder.
        num_oov_slots: number of extra UNK rows appended to the embedding
            matrix. Must equal how far the decoder's output vocabulary
            extends past the base vocabulary. Defaults to the previously
            hard-coded 50.

    Returns:
        Output of the final linear layer ([batch x edit_dim]).
    """
    with tf.variable_scope(OPS_NAME):
        # [VOCAB x word_dim]
        embeddings = vocab.get_embeddings()

        # Extend the embedding matrix so extended (OOV) ids map to the UNK
        # embedding.
        unk_id = vocab.get_token_id(vocab.UNKNOWN_TOKEN)
        unk_embed = tf.expand_dims(vocab.embed_tokens(unk_id), 0)
        unk_embeddings = tf.tile(unk_embed, [num_oov_slots, 1])

        # [VOCAB+num_oov_slots x word_dim]
        embeddings_extended = tf.concat([embeddings, unk_embeddings], axis=0)

        global_step = tf.train.get_global_step()
        temperature = tf.train.exponential_decay(
            temperature_starter, global_step, decay_steps, decay_rate,
            name='temperature')
        tf.summary.scalar('temper', temperature, ['extra'])

        # [batch x max_len x VOCAB+num_oov_slots], softmax probabilities
        outputs = decoder.rnn_output(decoder_output)

        # Replace non-positive entries (<= 0) with 1e-10 for numerical
        # stability of the relaxed categorical.
        outputs = tf.where(tf.less_equal(outputs, 0),
                           tf.ones_like(outputs) * 1e-10, outputs)

        # Differentiable (relaxed) approximation of one-hot token samples.
        dist = tfd.RelaxedOneHotCategorical(temperature, probs=outputs)

        # [batch x max_len x VOCAB+num_oov_slots], relaxed one-hot
        outputs_one_hot = dist.sample()

        # [batch x max_len x word_dim], one_hot^T * embedding_matrix
        outputs_embed = tf.einsum("btv,vd-> btd", outputs_one_hot,
                                  embeddings_extended)

        # [batch]
        outputs_length = decoder.seq_length(decoder_output)

        # Only the sentence-level embedding is used; per-step states are dropped.
        _, sentence_embedding = encoder.source_sent_encoder(
            outputs_embed, outputs_length, enc_hidden_dim, enc_num_layers,
            use_dropout=False, dropout_keep=1.0, swap_memory=swap_memory)

        h = sentence_embedding
        used_names = set()
        for i, units in enumerate(dense_layers):
            # Layers are named after their width. A repeated width would reuse
            # the same variable scope and TF1 would fail with "Variable ...
            # already exists". Repeats therefore get the layer index appended.
            # The first occurrence keeps the original name so existing
            # checkpoints still restore.
            name = 'hidden_%s' % (units,)
            if name in used_names:
                name = 'hidden_%s_%d' % (units, i)
            used_names.add(name)
            h = tf.layers.dense(h, units, activation='relu', name=name)

        # [batch x edit_dim]
        edit_vector = tf.layers.dense(h, edit_dim, activation=None,
                                      name='linear')

    return edit_vector