def prepare_decoder_input_output(tgt_words, tgt_len, vocab_table):
    """
    Args:
        tgt_words: tensor of word ids, [batch x max_len]
        tgt_len: vector of sentence lengths, [batch]
        vocab_table: instance of tf.vocab_lookup_table

    Returns:
        dec_input: tensor of word ids, [batch x max_len+1]
        dec_input_len: vector of sentence lengths, [batch]
        dec_output: tensor of word ids, [batch x max_len+1]
        dec_output_len: vector of sentence lengths, [batch]
    """
    start_token_id = vocab.get_token_id(vocab.START_TOKEN, vocab_table)
    stop_token_id = vocab.get_token_id(vocab.STOP_TOKEN, vocab_table)
    pad_token_id = vocab.get_token_id(vocab.PAD_TOKEN, vocab_table)

    dec_input = decoder.prepare_decoder_inputs(tgt_words, start_token_id)
    dec_input_len = seq.length_pre_embedding(dec_input)

    dec_output = decoder.prepare_decoder_output(tgt_words, tgt_len, stop_token_id, pad_token_id)
    dec_output_len = seq.length_pre_embedding(dec_output)

    return dec_input, dec_input_len, dec_output, dec_output_len
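# A self-contained sketch (illustration only, not part of the model) of the two
# preparation steps above, assuming prepare_decoder_inputs prepends START and
# prepare_decoder_output appends STOP before re-padding; the ids and the helper
# name below are made up.
import numpy as np

def _sketch_prepare_decoder_io(tgt_words, tgt_len, start_id=1, stop_id=2, pad_id=0):
    batch, max_len = tgt_words.shape
    dec_input = np.concatenate([np.full((batch, 1), start_id, dtype=tgt_words.dtype), tgt_words], axis=1)
    dec_output = np.full((batch, max_len + 1), pad_id, dtype=tgt_words.dtype)
    for i, l in enumerate(tgt_len):
        dec_output[i, :l] = tgt_words[i, :l]
        dec_output[i, l] = stop_id
    return dec_input, dec_output

# e.g. tgt_words=[[11, 12, 13, 0]], tgt_len=[3]
#   -> dec_input  = [[1, 11, 12, 13, 0]]   (length 4)
#   -> dec_output = [[11, 12, 13, 2, 0]]   (length 4)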
def get_trace_summary(vocab_i2s, pred_tokens, tgt_tokens, src_words, inserted_words,
                      deleted_words, pred_len, tgt_len):
    if pred_tokens.shape.ndims > 2:
        pred_joined = metrics.join_beams(pred_tokens, pred_len)
    else:
        pred_joined = metrics.join_tokens(pred_tokens, pred_len)

    tgt_joined = metrics.join_tokens(tgt_tokens, tgt_len)
    src_joined = metrics.join_tokens(vocab_i2s.lookup(src_words), length_pre_embedding(src_words))
    iw_joined = metrics.join_tokens(vocab_i2s.lookup(inserted_words), length_pre_embedding(inserted_words), ', ')
    dw_joined = metrics.join_tokens(vocab_i2s.lookup(deleted_words), length_pre_embedding(deleted_words), ', ')

    return tf.concat([src_joined, iw_joined, dw_joined, tgt_joined, pred_joined], axis=1)
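# Rough sketch of what metrics.join_tokens is assumed to do for the trace above:
# mask out padding by sequence length and join the remaining string tokens with a
# separator, returning a [batch x 1] string tensor (padding positions contribute
# empty strings, so a real implementation would also strip trailing separators).
# The name and behaviour here are assumptions, not the project's actual code.
import tensorflow as tf

def _join_tokens_sketch(tokens, lengths, separator=' '):
    mask = tf.sequence_mask(lengths, maxlen=tf.shape(tokens)[1])
    kept = tf.where(mask, tokens, tf.fill(tf.shape(tokens), ''))
    joined = tf.reduce_join(kept, axis=1, separator=separator)
    return tf.expand_dims(joined, 1)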
def encode_all(self, base_word_ids, source_word_ids, target_word_ids, insert_word_ids, common_word_ids):
    batch_size = tf.shape(base_word_ids)[0]

    with tf.name_scope('encode_all'):
        base_len = seq.length_pre_embedding(base_word_ids)
        src_len = seq.length_pre_embedding(source_word_ids)
        tgt_len = seq.length_pre_embedding(target_word_ids)
        iw_len = seq.length_pre_embedding(insert_word_ids)
        cw_len = seq.length_pre_embedding(common_word_ids)

        base_encoded, base_attention_bias = self.encoder(base_word_ids, base_len)

        kill_edit = self.config.editor.kill_edit
        draw_edit = self.config.editor.draw_edit

        if self.config.editor.decoder.allow_mev_st_attn \
                or self.config.editor.decoder.allow_mev_ts_attn:
            assert kill_edit == False and draw_edit == False

        if kill_edit:
            edit_vector = tf.zeros(shape=(batch_size, self.config.editor.edit_encoder.edit_dim))
            mev_st = mev_ts = None
        else:
            if draw_edit:
                edit_vector = random_noise_encoder(batch_size,
                                                   self.config.editor.edit_encoder.edit_dim,
                                                   self.config.editor.norm_max)
                mev_st = mev_ts = None
            else:
                edit_vector, mev_st, mev_ts = self.edit_encoder(
                    source_word_ids, target_word_ids,
                    insert_word_ids, common_word_ids,
                    src_len, tgt_len, iw_len, cw_len,
                )

        encoder_outputs = (base_encoded, base_attention_bias)
        edit_encoder_outputs = (edit_vector, mev_st, mev_ts)

    return encoder_outputs, edit_encoder_outputs
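# The draw_edit branch above replaces the learned edit vector with random noise.
# A plausible sketch of random_noise_encoder, assuming it samples a uniformly
# random direction with a magnitude drawn up to norm_max (the exact sampling
# scheme used by the project is an assumption here).
import tensorflow as tf

def _random_noise_encoder_sketch(batch_size, edit_dim, norm_max):
    direction = tf.nn.l2_normalize(tf.random_normal([batch_size, edit_dim]), axis=1)
    magnitude = tf.random_uniform([batch_size, 1], minval=0., maxval=norm_max)
    return direction * magnitude  # [batch x edit_dim], norm <= norm_max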
def _prepare_inputs(self, output_word_ids: tf.Tensor, edit_vector: tf.Tensor):
    # Add start token to decoder inputs
    decoder_input_words = prepare_decoder_input(output_word_ids)  # [batch, output_len+1]
    decoder_input_max_len = tf.shape(decoder_input_words)[1]
    decoder_input_len = sequence.length_pre_embedding(decoder_input_words)  # [batch]

    # Get word embeddings
    decoder_input_embeds = self.embedding_layer(decoder_input_words)  # [batch, output_len+1, hidden_size]

    # Add positional encoding to the embeddings
    with tf.name_scope('positional_encoding'):
        pos_encoding = model_utils.get_position_encoding(decoder_input_max_len,
                                                         self.config.orig_hidden_size)
        decoder_input_embeds += pos_encoding

    decoder_input = decoder_input_embeds
    if self.config.enable_dropout and self.config.layer_postprocess_dropout > 0.:
        decoder_input = tf.nn.dropout(decoder_input, 1 - self.config.layer_postprocess_dropout)

    return decoder_input, decoder_input_len
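# model_utils.get_position_encoding is assumed to be the standard Transformer
# sinusoidal encoding (as in the official TensorFlow implementation); a minimal
# sketch for reference only:
import math
import tensorflow as tf

def _position_encoding_sketch(length, hidden_size, min_timescale=1.0, max_timescale=1.0e4):
    position = tf.to_float(tf.range(length))
    num_timescales = hidden_size // 2
    log_timescale_increment = (math.log(float(max_timescale) / float(min_timescale)) /
                               (tf.to_float(num_timescales) - 1))
    inv_timescales = min_timescale * tf.exp(tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
    # [length, hidden_size]: sin in the first half of the channels, cos in the second.
    return tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)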
def test_wtf():
    with tf.Graph().as_default():
        V, embed_matrix = vocab.read_word_embeddings(
            Path('../data') / 'word_vectors' / 'glove.6B.300d_yelp.txt',
            300,
            10000
        )
        table = vocab.create_vocab_lookup_tables(V)
        vocab_s2i = table[vocab.STR_TO_INT]
        vocab_i2s = table[vocab.INT_TO_STR]

        dataset = input_fn('../data/yelp_dataset_large_split/train.tsv', table, 64, 1)
        iter = dataset.make_initializable_iterator()
        (src, tgt, iw, dw), _ = iter.get_next()

        src_len = length_pre_embedding(src)
        tgt_len = length_pre_embedding(tgt)
        iw_len = length_pre_embedding(iw)
        dw_len = length_pre_embedding(dw)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, vocab.get_token_id(vocab.START_TOKEN, vocab_s2i))
        dec_output = decoder.prepare_decoder_output(tgt, tgt_len,
                                                    vocab.get_token_id(vocab.STOP_TOKEN, vocab_s2i),
                                                    vocab.get_token_id(vocab.PAD_TOKEN, vocab_s2i))

        t_src = vocab_i2s.lookup(src)
        t_tgt = vocab_i2s.lookup(tgt)
        t_iw = vocab_i2s.lookup(iw)
        t_dw = vocab_i2s.lookup(dw)
        t_do = vocab_i2s.lookup(dec_output)
        t_di = vocab_i2s.lookup(dec_inputs)

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer(),
                      tf.tables_initializer()])
            sess.run(iter.initializer)

            while True:
                try:
                    # src, tgt, iw, dw = sess.run([src, tgt, iw, dw])
                    ts, tt, tiw, tdw, tdo, tdi = sess.run([t_src, t_tgt, t_iw, t_dw, t_do, t_di])
                except tf.errors.OutOfRangeError:
                    # End of dataset.
                    break
def test_rnn_encoder(dataset_file, embedding_file):
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE, NUM_EPOCH)

        embedding = tf.get_variable('embeddings',
                                    shape=embed_matrix.shape,
                                    initializer=tf.constant_initializer(embed_matrix))

        iter = dataset.make_initializable_iterator()
        (src, tgt, iw, dw), _ = iter.get_next()

        EDIT_DIM = 8
        output = ev.rnn_encoder(tf.nn.embedding_lookup(embedding, src),
                                tf.nn.embedding_lookup(embedding, tgt),
                                tf.nn.embedding_lookup(embedding, iw),
                                tf.nn.embedding_lookup(embedding, dw),
                                sequence.length_pre_embedding(src),
                                sequence.length_pre_embedding(tgt),
                                sequence.length_pre_embedding(iw),
                                sequence.length_pre_embedding(dw),
                                256, 2, 256, 1,
                                EDIT_DIM, 100.0, 0.1, 14.0, 0.8)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(iter.initializer)

            while True:
                try:
                    oeo = sess.run(output)
                    assert oeo.shape == (BATCH_SIZE, EDIT_DIM)
                except tf.errors.OutOfRangeError:
                    break
def add_extra_summary_avg_bleu(vocab_i2s, decoder_output, ref_words, collections=None):
    hypo_tokens = decoder.str_tokens(decoder_output, vocab_i2s)
    hypo_len = decoder.seq_length(decoder_output)

    ref_tokens = vocab_i2s.lookup(ref_words)
    ref_len = length_pre_embedding(ref_words)

    avg_bleu = get_avg_bleu_smmary(ref_tokens, hypo_tokens, ref_len, hypo_len)
    tf.summary.scalar('bleu', avg_bleu, collections)

    return avg_bleu
def add_extra_summary_avg_bleu(hypo_tokens, hypo_len, ref_words, vocab_i2s, collections=None):
    ref_tokens = vocab_i2s.lookup(ref_words)
    ref_len = length_pre_embedding(ref_words)

    avg_bleu = get_avg_bleu(ref_tokens, hypo_tokens, ref_len, hypo_len)
    tf.summary.scalar('bleu', avg_bleu, collections)

    return avg_bleu
def calculate_loss(logits, output_words, input_words, tgt_words, label_smoothing, vocab_size):
    gold = decoder.prepare_decoder_output(output_words, sequence.length_pre_embedding(output_words))
    gold_len = sequence.length_pre_embedding(gold)

    gold_input = decoder.prepare_decoder_output(input_words, sequence.length_pre_embedding(input_words))
    gold_input_len = sequence.length_pre_embedding(gold_input)

    gold_tgt = decoder.prepare_decoder_output(tgt_words, sequence.length_pre_embedding(tgt_words))
    gold_tgt_len = sequence.length_pre_embedding(gold_tgt)

    main_loss, _ = optimizer.padded_cross_entropy_loss(logits, gold, gold_len,
                                                       label_smoothing, vocab_size)
    input_loss, _ = optimizer.padded_cross_entropy_loss(logits, gold_input, gold_input_len,
                                                        label_smoothing, vocab_size)
    tgt_loss, _ = optimizer.padded_cross_entropy_loss(logits, gold_tgt, gold_tgt_len,
                                                      label_smoothing, vocab_size)

    # Fit the gold output, while subtracting down-weighted cross entropies against the
    # raw input (1/50) and the target-side sentence (1/30), which pushes the decoder
    # away from simply reproducing either of them.
    total_loss = main_loss - 1. / 50 * input_loss - 1. / 30 * tgt_loss

    return total_loss
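# optimizer.padded_cross_entropy_loss is assumed to be a label-smoothed,
# padding-aware cross entropy in the style of the Transformer reference code;
# the signature and normalisation below are assumptions, kept only as a sketch.
import tensorflow as tf

def _padded_cross_entropy_loss_sketch(logits, labels, lengths, smoothing, vocab_size):
    confidence = 1.0 - smoothing
    low_confidence = smoothing / tf.to_float(vocab_size - 1)
    soft_targets = tf.one_hot(labels, depth=vocab_size,
                              on_value=confidence, off_value=low_confidence)
    xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=soft_targets)  # [batch, time]
    weights = tf.sequence_mask(lengths, tf.shape(labels)[1], dtype=tf.float32)
    loss = tf.reduce_sum(xentropy * weights) / tf.reduce_sum(weights)
    return loss, weights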
def add_extra_summary_avg_bleu(vocab_i2s, decoder_output, tgt_words, collections=None):
    pred_tokens = decoder.str_tokens(decoder_output, vocab_i2s)
    pred_len = decoder.seq_length(decoder_output)

    tgt_tokens = vocab_i2s.lookup(tgt_words)
    tgt_len = length_pre_embedding(tgt_words)

    avg_bleu = get_avg_bleu_smmary(tgt_tokens, pred_tokens, tgt_len, pred_len)
    tf.summary.scalar('bleu', avg_bleu, collections)

    return avg_bleu
def add_extra_summary_trace(vocab_i2s, decoder_output, base_words, output_words,
                            src_words, tgt_words, inserted_words, deleted_words, collections=None):
    pred_tokens = decoder.str_tokens(decoder_output, vocab_i2s)
    pred_len = decoder.seq_length(decoder_output)

    tgt_tokens = vocab_i2s.lookup(tgt_words)
    tgt_len = length_pre_embedding(tgt_words)

    trace_summary = get_trace_summary(vocab_i2s, pred_tokens, tgt_tokens, src_words,
                                      inserted_words, deleted_words, pred_len, tgt_len)
    tf.summary.text('trace', trace_summary, collections)

    return trace_summary
def get_logits(self, encoded_inputs, output_word_ids):
    with tf.name_scope('logits'):
        encoder_outputs, edit_encoder_outputs = encoded_inputs
        base_sent_hidden_states, base_sent_attention_bias = encoder_outputs
        edit_vector, mev_st, mev_ts = edit_encoder_outputs

        output_len = seq.length_pre_embedding(output_word_ids)

        logits = self.decoder(output_word_ids, output_len,
                              base_sent_hidden_states, base_sent_attention_bias,
                              edit_vector, mev_st, mev_ts,
                              mode='train')

    return logits
def add_extra_summary_trace(pred_tokens, pred_len, base_words, output_words,
                            src_words, tgt_words, inserted_words, deleted_words, collections=None):
    vocab_i2s = vocab.get_vocab_lookup_tables()[vocab.INT_TO_STR]

    tgt_tokens = vocab_i2s.lookup(tgt_words)
    tgt_len = length_pre_embedding(tgt_words)

    trace_summary = get_trace(pred_tokens, tgt_tokens, src_words,
                              inserted_words, deleted_words, pred_len, tgt_len)
    tf.summary.text('trace', trace_summary, collections)

    return trace_summary
def test_length_pre_embedding():
    def generate_sequence(seq_len):
        seq = []
        for i in range(MAX_LEN):
            if i < seq_len:
                seq.append(random.randint(4, 1000))
            else:
                seq.append(0)
        return seq

    gold_seq_lengths = [random.randint(4, MAX_LEN) for _ in range(BATCH_SIZE - 1)] + [MAX_LEN]
    sequence_batch = [generate_sequence(l) for l in gold_seq_lengths]
    batch = np.array(sequence_batch, dtype=np.float32)

    tf.enable_eager_execution()
    lengths = length_pre_embedding(batch)

    assert lengths.shape == (BATCH_SIZE,)
    assert list(lengths.numpy()) == gold_seq_lengths
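# The contract this test checks, and which the rest of the file relies on:
# length_pre_embedding counts the non-padding (non-zero) ids in each row of an
# id matrix. A minimal implementation satisfying the test could look like this
# (a sketch, not necessarily the project's version):
import tensorflow as tf

def _length_pre_embedding_sketch(word_ids):
    used = tf.sign(tf.abs(word_ids))                       # 1 for real tokens, 0 for PAD (id 0)
    return tf.cast(tf.reduce_sum(used, axis=1), tf.int32)  # [batch]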
def test_context_encoder(dataset_file, embedding_file):
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE, NUM_EPOCH)

        embedding = tf.get_variable('embeddings',
                                    shape=embed_matrix.shape,
                                    initializer=tf.constant_initializer(embed_matrix))

        iter = dataset.make_initializable_iterator()
        (_, _, src, _), _ = iter.get_next()

        src_len = sequence.length_pre_embedding(src)
        src_embd = tf.nn.embedding_lookup(embedding, src)

        output = ev.context_encoder(src_embd, src_len, HIDDEN_DIM, NUM_LAYER)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(iter.initializer)

            while True:
                try:
                    oeo, o_src, o_src_len, o_src_embd = sess.run([output, src, src_len, src_embd])
                    assert oeo.shape == (BATCH_SIZE, o_src_len.max(), HIDDEN_DIM)
                except tf.errors.OutOfRangeError:
                    break
def test_encoder(dataset_file, embedding_file):
    d_fn, gold_dataset = dataset_file
    e_fn, gold_embeds = embedding_file

    v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
    vocab_lookup = vocab.get_vocab_lookup(v)

    dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE, NUM_EPOCH)

    embedding = tf.get_variable('embeddings',
                                shape=embed_matrix.shape,
                                initializer=tf.constant_initializer(embed_matrix))

    iter = dataset.make_initializable_iterator()
    (src, _, _, _), _ = iter.get_next()

    src_len = sequence.length_pre_embedding(src)
    src_embd = tf.nn.embedding_lookup(embedding, src)

    encoder_output, _ = encoder.bidirectional_encoder(src_embd, src_len, HIDDEN_DIM, NUM_LAYER, 0.9)

    with tf.Session() as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
            tf.tables_initializer()
        ])
        sess.run(iter.initializer)

        oeo, o_src, o_src_len, o_src_embd = sess.run([encoder_output, src, src_len, src_embd])

        for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            print(i)

        assert oeo.shape == (BATCH_SIZE, o_src_len.max(), HIDDEN_DIM)
def editor_train(base_words, source_words, target_words, insert_words, delete_words,
                 embed_matrix, vocab_table,
                 hidden_dim, agenda_dim, edit_dim,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 max_sent_length, dropout_keep, lamb_reg, norm_eps, norm_max,
                 kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len,
        hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep, swap_memory=swap_memory)

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = edit_encoder.random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector = edit_encoder.accumulator_encoder(
                insert_word_embeds, delete_word_embeds,
                iw_len, dw_len,
                edit_dim, lamb_reg, norm_eps, norm_max, dropout_keep,
                enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(target_words, tgt_len, vocab_table)

    train_decoder = decoder.train_decoder(
        input_agenda, embeddings,
        train_dec_inp, base_sent_hidden_states, insert_word_embeds, delete_word_embeds,
        train_dec_inp_len, base_len, iw_len, dw_len,
        attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states, insert_word_embeds, delete_word_embeds,
            base_len, iw_len, dw_len,
            attn_dim, hidden_dim, num_decoder_layers, max_sent_length, beam_width,
            swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN, vocab_table),
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_table),
            base_sent_hidden_states, insert_word_embeds, delete_word_embeds,
            base_len, iw_len, dw_len,
            attn_dim, hidden_dim, num_decoder_layers, max_sent_length,
            swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def editor_train(base_words, extended_base_words, output_words, extended_output_words,
                 source_words, target_words, insert_words, delete_words, oov,
                 vocab_size, hidden_dim, agenda_dim, edit_dim, micro_edit_ev_dim, num_heads,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim, wa_hidden_layer,
                 meve_hidden_dim, meve_hidden_layers,
                 max_sent_length, dropout_keep, lamb_reg, norm_eps, norm_max,
                 kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len,
        hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep, swap_memory=swap_memory)

    assert kill_edit == False and draw_edit == False

    # [batch x agenda_dim]
    base_agenda = linear(base_sent_embed, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words, tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda, embeddings, extended_base_words, oov,
        train_dec_inp, train_dec_inp_extended, base_sent_hidden_states,
        train_dec_inp_len, base_len,
        vocab_size, attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, base_len,
            vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, base_len,
            vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    add_decoder_attn_history_graph(infr_decoder)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def editor_train(base_words, extended_base_words, output_words, extended_output_words,
                 source_words, target_words, insert_words, delete_words, oov,
                 vocab_size, hidden_dim, agenda_dim, edit_dim, micro_edit_ev_dim, num_heads,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim, wa_hidden_layer,
                 meve_hidden_dim, meve_hidden_layers, recons_dense_layers,
                 max_sent_length, dropout_keep, lamb_reg, norm_eps, norm_max,
                 kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(extended_output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    output_word_embeds = vocab.embed_tokens(output_words)
    # src_word_embeds = vocab.embed_tokens(source_words)
    # tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = encoder.source_sent_encoder(
        base_word_embeds, base_len,
        hidden_dim, num_encoder_layers,
        use_dropout=use_dropout, dropout_keep=dropout_keep, swap_memory=swap_memory)

    assert kill_edit == False and draw_edit == False

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector = accumulator_encoder(
                insert_word_embeds, delete_word_embeds,
                iw_len, dw_len,
                edit_dim, lamb_reg, norm_eps, norm_max, dropout_keep,
                enable_vae=enable_vae)

    wa_inserted, wa_deleted = (tf.constant([[0]]), tf.constant([[0]])), \
                              (tf.constant([[0]]), tf.constant([[0]]))

    # [batch x agenda_dim]
    base_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, extended_output_words, output_len)
    train_dec_inp_extended = prepare_decoder_inputs(extended_output_words, tf.cast(-1, tf.int64))

    train_decoder = decoder.train_decoder(
        base_agenda, embeddings, extended_base_words, oov,
        train_dec_inp, train_dec_inp_extended, base_sent_hidden_states,
        wa_inserted, wa_deleted,
        train_dec_inp_len, base_len, src_len, tgt_len,
        vocab_size, attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, wa_inserted, wa_deleted,
            base_len, src_len, tgt_len,
            vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            base_agenda, embeddings, extended_base_words, oov,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, wa_inserted, wa_deleted,
            base_len, src_len, tgt_len,
            vocab_size, attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    edit_vector_recons = output_words_to_edit_vector(
        output_word_embeds, output_len,
        edit_dim, ctx_hidden_dim, ctx_hidden_layer,
        recons_dense_layers, swap_memory)

    optimizer.add_reconst_loss(edit_vector, edit_vector_recons)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def model_fn(features, mode, config, embedding_matrix, vocab_tables):
    if mode == tf.estimator.ModeKeys.PREDICT:
        base_words, _, src_words, tgt_words, inserted_words, common_words = features
        output_words = tgt_words
    else:
        base_words, output_words, src_words, tgt_words, inserted_words, common_words = features

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    tf.add_to_collection('is_training', is_training)

    if mode != tf.estimator.ModeKeys.TRAIN:
        config.put('editor.enable_dropout', False)
        config.put('editor.dropout_keep', 1.0)
        config.put('editor.dropout', 0.0)

        config.put('editor.transformer.enable_dropout', False)
        config.put('editor.transformer.layer_postprocess_dropout', 0.0)
        config.put('editor.transformer.attention_dropout', 0.0)
        config.put('editor.transformer.relu_dropout', 0.0)

    vocab.init_embeddings(embedding_matrix)
    EmbeddingSharedWeights.init_from_embedding_matrix()

    editor_model = Editor(config)
    logits, beam_prediction = editor_model(base_words, src_words, tgt_words,
                                           inserted_words, common_words, output_words)

    targets = decoder.prepare_decoder_output(output_words, sequence.length_pre_embedding(output_words))
    target_lengths = sequence.length_pre_embedding(targets)

    vocab_size = embedding_matrix.shape[0]
    loss, weights = optimizer.padded_cross_entropy_loss(
        logits, targets, target_lengths, config.optim.label_smoothing, vocab_size)

    train_op = optimizer.get_train_op(loss, config)

    tf.logging.info("Trainable variables:")
    for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        tf.logging.info(str(i))

    tf.logging.info("Number of trainable parameters:")
    tf.logging.info(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))

    if mode == tf.estimator.ModeKeys.TRAIN:
        decoded_ids = decoder.logits_to_decoded_ids(logits)
        ops = add_extra_summary(config, decoded_ids, target_lengths,
                                base_words, output_words,
                                src_words, tgt_words,
                                inserted_words, common_words,
                                collections=['extra'])

        hooks = [
            get_train_extra_summary_writer(config),
            get_extra_summary_logger(ops, config),
        ]

        if config.get('logger.enable_profiler', False):
            hooks.append(get_profiler_hook(config))

        return tf.estimator.EstimatorSpec(mode,
                                          train_op=train_op,
                                          loss=loss,
                                          training_hooks=hooks)

    elif mode == tf.estimator.ModeKeys.EVAL:
        decoded_ids = decoder.logits_to_decoded_ids(logits)
        ops = add_extra_summary(config, decoded_ids, target_lengths,
                                base_words, output_words,
                                src_words, tgt_words,
                                inserted_words, common_words,
                                collections=['extra'])

        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            evaluation_hooks=[get_extra_summary_logger(ops, config)],
            eval_metric_ops={'bleu': tf_metrics.streaming_mean(ops[ES_BLEU])})

    elif mode == tf.estimator.ModeKeys.PREDICT:
        decoded_ids, decoded_lengths, scores = beam_prediction
        tokens = decoder.str_tokens(decoded_ids)

        preds = {
            'str_tokens': tf.transpose(tokens, [0, 2, 1]),
            'decoded_ids': tf.transpose(decoded_ids, [0, 2, 1]),
            'lengths': decoded_lengths,
            'joined': metrics.join_tokens(tokens, decoded_lengths)
        }

        tmee_attentions = tf.get_collection('TransformerMicroEditExtractor_Attentions')
        if len(tmee_attentions) > 0:
            preds.update({
                'tmee_attentions_st_enc_self': tmee_attentions[0][0],
                'tmee_attentions_st_dec_self': tmee_attentions[0][1],
                'tmee_attentions_st_dec_enc': tmee_attentions[0][2],

                'tmee_attentions_ts_enc_self': tmee_attentions[1][0],
                'tmee_attentions_ts_dec_self': tmee_attentions[1][1],
                'tmee_attentions_ts_dec_enc': tmee_attentions[1][2],

                'src_words': src_words,
                'tgt_words': tgt_words,
                'base_words': base_words,
                'output_words': output_words
            })

        return tf.estimator.EstimatorSpec(mode, predictions=preds)
def editor_train(base_words, output_words, source_words, target_words, insert_words, delete_words,
                 hidden_dim, agenda_dim, edit_dim, micro_edit_ev_dim, num_heads,
                 num_encoder_layers, num_decoder_layers, attn_dim, beam_width,
                 ctx_hidden_dim, ctx_hidden_layer, wa_hidden_dim, wa_hidden_layer,
                 meve_hidden_dim, meve_hidden_layers,
                 max_sent_length, dropout_keep, lamb_reg, norm_eps, norm_max,
                 kill_edit, draw_edit, swap_memory,
                 use_beam_decoder=False, use_dropout=False,
                 no_insert_delete_attn=False, enable_vae=True):
    batch_size = tf.shape(source_words)[0]

    # [batch]
    base_len = seq.length_pre_embedding(base_words)
    output_len = seq.length_pre_embedding(output_words)
    src_len = seq.length_pre_embedding(source_words)
    tgt_len = seq.length_pre_embedding(target_words)
    iw_len = seq.length_pre_embedding(insert_words)
    dw_len = seq.length_pre_embedding(delete_words)

    # variable of shape [vocab_size, embed_dim]
    embeddings = vocab.get_embeddings()

    # [batch x max_len x embed_dim]
    base_word_embeds = vocab.embed_tokens(base_words)
    src_word_embeds = vocab.embed_tokens(source_words)
    tgt_word_embeds = vocab.embed_tokens(target_words)
    insert_word_embeds = vocab.embed_tokens(insert_words)
    delete_word_embeds = vocab.embed_tokens(delete_words)

    sent_encoder = tf.make_template('sent_encoder', encoder.source_sent_encoder,
                                    hidden_dim=hidden_dim,
                                    num_layer=num_encoder_layers,
                                    swap_memory=swap_memory,
                                    use_dropout=use_dropout,
                                    dropout_keep=dropout_keep)

    # [batch x max_len x rnn_out_dim], [batch x rnn_out_dim]
    base_sent_hidden_states, base_sent_embed = sent_encoder(base_word_embeds, base_len)

    assert kill_edit == False and draw_edit == False

    # [batch x edit_dim]
    if kill_edit:
        edit_vector = tf.zeros(shape=(batch_size, edit_dim))
    else:
        if draw_edit:
            edit_vector = random_noise_encoder(batch_size, edit_dim, norm_max)
        else:
            edit_vector, wa_inserted, wa_deleted = attn_encoder(
                src_word_embeds, tgt_word_embeds,
                insert_word_embeds, delete_word_embeds,
                src_len, tgt_len, iw_len, dw_len,
                ctx_hidden_dim, ctx_hidden_layer,
                wa_hidden_dim, wa_hidden_layer,
                meve_hidden_dim, meve_hidden_layers,
                edit_dim, micro_edit_ev_dim, num_heads,
                lamb_reg, norm_eps, norm_max, sent_encoder,
                use_dropout=use_dropout, dropout_keep=dropout_keep,
                swap_memory=swap_memory, enable_vae=enable_vae)

    # [batch x agenda_dim]
    input_agenda = agn.linear(base_sent_embed, edit_vector, agenda_dim)

    train_dec_inp, train_dec_inp_len, \
    train_dec_out, train_dec_out_len = prepare_decoder_input_output(output_words, output_len, None)

    train_decoder = decoder.train_decoder(
        input_agenda, embeddings,
        train_dec_inp, base_sent_hidden_states, wa_inserted, wa_deleted,
        train_dec_inp_len, base_len, iw_len, dw_len,
        attn_dim, hidden_dim, num_decoder_layers, swap_memory,
        enable_dropout=use_dropout, dropout_keep=dropout_keep,
        no_insert_delete_attn=no_insert_delete_attn)

    if use_beam_decoder:
        infr_decoder = decoder.beam_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, wa_inserted, wa_deleted,
            base_len, iw_len, dw_len,
            attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, beam_width, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)
    else:
        infr_decoder = decoder.greedy_eval_decoder(
            input_agenda, embeddings,
            vocab.get_token_id(vocab.START_TOKEN),
            vocab.get_token_id(vocab.STOP_TOKEN),
            base_sent_hidden_states, wa_inserted, wa_deleted,
            base_len, iw_len, dw_len,
            attn_dim, hidden_dim, num_decoder_layers,
            max_sent_length, swap_memory,
            enable_dropout=use_dropout, dropout_keep=dropout_keep,
            no_insert_delete_attn=no_insert_delete_attn)

    return train_decoder, infr_decoder, train_dec_out, train_dec_out_len
def test_decoder_prepares(dataset_file, embedding_file):
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        stop_token = tf.constant(bytes(vocab.STOP_TOKEN, encoding='utf8'), dtype=tf.string)
        stop_token_id = vocab_lookup.lookup(stop_token)

        start_token = tf.constant(bytes(vocab.START_TOKEN, encoding='utf8'), dtype=tf.string)
        start_token_id = vocab_lookup.lookup(start_token)

        pad_token = tf.constant(bytes(vocab.PAD_TOKEN, encoding='utf8'), dtype=tf.string)
        pad_token_id = vocab_lookup.lookup(pad_token)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE, NUM_EPOCH)
        iter = dataset.make_initializable_iterator()
        (_, tgt, _, _), _ = iter.get_next()

        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len, stop_token_id, pad_token_id)

        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        dec_outputs_last = sequence.last_relevant(tf.expand_dims(dec_outputs, 2), dec_outputs_len)
        dec_outputs_last = tf.squeeze(dec_outputs_last)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(iter.initializer)

            while True:
                try:
                    # Fetch into fresh names so the graph tensors are not clobbered
                    # between iterations.
                    o_di, o_do, o_tgt_len, dil, dol, o_start_id, o_stop_id, o_do_last, o_tgt = sess.run([
                        dec_inputs, dec_outputs, tgt_len,
                        dec_inputs_len, dec_outputs_len,
                        start_token_id, stop_token_id,
                        dec_outputs_last, tgt
                    ])

                    assert list(dil) == list(dol) == list(o_tgt_len + 1)
                    assert list(o_di[:, 0]) == list(np.ones_like(o_di[:, 0]) * o_start_id)
                    assert list(o_do_last) == list(np.ones_like(o_do_last) * o_stop_id)
                except tf.errors.OutOfRangeError:
                    break
def test_decoder_train(dataset_file, embedding_file):
    with tf.Graph().as_default():
        d_fn, gold_dataset = dataset_file
        e_fn, gold_embeds = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        stop_token = tf.constant(bytes(vocab.STOP_TOKEN, encoding='utf8'), dtype=tf.string)
        stop_token_id = vocab_lookup.lookup(stop_token)

        start_token = tf.constant(bytes(vocab.START_TOKEN, encoding='utf8'), dtype=tf.string)
        start_token_id = vocab_lookup.lookup(start_token)

        pad_token = tf.constant(bytes(vocab.PAD_TOKEN, encoding='utf8'), dtype=tf.string)
        pad_token_id = vocab_lookup.lookup(pad_token)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE, NUM_EPOCH)
        iter = dataset.make_initializable_iterator()
        (src, tgt, inw, dlw), _ = iter.get_next()

        src_len = sequence.length_pre_embedding(src)
        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len, stop_token_id, pad_token_id)

        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        batch_size = tf.shape(src)[0]
        edit_vector = edit_encoder.random_noise_encoder(batch_size, EDIT_DIM, 14.0)

        embedding = tf.get_variable('embeddings',
                                    shape=embed_matrix.shape,
                                    initializer=tf.constant_initializer(embed_matrix))

        src_embd = tf.nn.embedding_lookup(embedding, src)
        src_sent_embeds, final_states = encoder.source_sent_encoder(src_embd, src_len, 20, 3, 0.8)

        agn = agenda.linear(final_states, edit_vector, 4)

        dec_out = decoder.train_decoder(agn, embedding,
                                        dec_inputs, src_sent_embeds,
                                        tf.nn.embedding_lookup(embedding, inw),
                                        tf.nn.embedding_lookup(embedding, dlw),
                                        dec_inputs_len, src_len,
                                        sequence.length_pre_embedding(inw),
                                        sequence.length_pre_embedding(dlw),
                                        5, 20, 3, False)

        # eval_dec_out = decoder.greedy_eval_decoder(
        #     agn, embedding,
        #     start_token_id, stop_token_id,
        #     src_sent_embeds,
        #     tf.nn.embedding_lookup(embedding, inw),
        #     tf.nn.embedding_lookup(embedding, dlw),
        #     src_len, sequence.length_pre_embedding(inw), sequence.length_pre_embedding(dlw),
        #     5, 20, 3, 40
        # )

        eval_dec_out = decoder.beam_eval_decoder(agn, embedding,
                                                 start_token_id, stop_token_id,
                                                 src_sent_embeds,
                                                 tf.nn.embedding_lookup(embedding, inw),
                                                 tf.nn.embedding_lookup(embedding, dlw),
                                                 src_len,
                                                 sequence.length_pre_embedding(inw),
                                                 sequence.length_pre_embedding(dlw),
                                                 5, 20, 3, 40)

        # saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
        # s = tf.summary.FileWriter('data/an')
        # s.add_graph(g)
        #
        # all_print = tf.get_collection('print')

        an, final_states, dec_len = dec_out
        stacked = decoder.attention_score(dec_out)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer()
            ])
            sess.run(iter.initializer)

            print(sess.run([eval_dec_out]))