def prepare_decoder_input_output(tgt_words, tgt_len, vocab_table):
    """Derive decoder input/target tensors (and their lengths) from target ids.

    The start, stop and pad token ids are resolved through ``vocab_table``.
    The decoder input is built with the start id, and the decoder output with
    the stop and pad ids. Both lengths are measured with
    ``seq.length_pre_embedding``.

    Args:
        tgt_words: tensor of word ids, [batch x max_len]
        tgt_len: vector of sentence lengths, [batch]
        vocab_table: instance of tf.vocab_lookup_table

    Returns:
        A 4-tuple ``(dec_input, dec_input_len, dec_output, dec_output_len)``:
            dec_input: tensor of word ids, [batch x max_len+1]
            dec_input_len: vector of sentence lengths, [batch]
            dec_output: tensor of word ids, [batch x max_len+1]
            dec_output_len: vector of sentence lengths, [batch]
    """
    def _id_of(token):
        return vocab.get_token_id(token, vocab_table)

    # Resolved in the order start, stop, pad (same graph op creation order).
    start_id, stop_id, pad_id = [
        _id_of(token)
        for token in (vocab.START_TOKEN, vocab.STOP_TOKEN, vocab.PAD_TOKEN)
    ]

    inputs = decoder.prepare_decoder_inputs(tgt_words, start_id)
    inputs_len = seq.length_pre_embedding(inputs)

    outputs = decoder.prepare_decoder_output(tgt_words, tgt_len, stop_id, pad_id)
    outputs_len = seq.length_pre_embedding(outputs)

    return inputs, inputs_len, outputs, outputs_len
def test_wtf():
    """Smoke-test the Yelp input pipeline and decoder preparation end to end.

    Steps:
    * build the dataset from the training split;
    * prepare decoder inputs/outputs from the target ids;
    * map every id tensor back to strings via the int-to-string table;
    * drain the iterator until it is exhausted.

    Nothing is asserted. The test fails only if graph construction or a
    ``sess.run`` call raises an error other than end-of-data.
    """
    with tf.Graph().as_default():
        V, _ = vocab.read_word_embeddings(
            Path('../data') / 'word_vectors' / 'glove.6B.300d_yelp.txt', 300, 10000
        )
        table = vocab.create_vocab_lookup_tables(V)
        vocab_s2i = table[vocab.STR_TO_INT]
        vocab_i2s = table[vocab.INT_TO_STR]

        dataset = input_fn('../data/yelp_dataset_large_split/train.tsv', table, 64, 1)
        iterator = dataset.make_initializable_iterator()
        (src, tgt, iw, dw), _ = iterator.get_next()

        tgt_len = length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(
            tgt, vocab.get_token_id(vocab.START_TOKEN, vocab_s2i))
        dec_output = decoder.prepare_decoder_output(
            tgt, tgt_len,
            vocab.get_token_id(vocab.STOP_TOKEN, vocab_s2i),
            vocab.get_token_id(vocab.PAD_TOKEN, vocab_s2i))

        # Map ids back to strings so the lookups are exercised too.
        t_src = vocab_i2s.lookup(src)
        t_tgt = vocab_i2s.lookup(tgt)
        t_iw = vocab_i2s.lookup(iw)
        t_dw = vocab_i2s.lookup(dw)
        t_do = vocab_i2s.lookup(dec_output)
        t_di = vocab_i2s.lookup(dec_inputs)

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer(),
                      tf.tables_initializer()])
            sess.run(iterator.initializer)
            while True:
                try:
                    sess.run([t_src, t_tgt, t_iw, t_dw, t_do, t_di])
                except tf.errors.OutOfRangeError:
                    # Raised by get_next() once the dataset is exhausted.
                    # Any other error must propagate and fail the test.
                    break
def test_decoder_train(dataset_file, embedding_file):
    """Build the edit-model graph and run the beam-search decoder once.

    The graph wires together:
    * the vocabulary lookup and the dataset iterator;
    * decoder input/output preparation;
    * a random-noise edit vector;
    * the source sentence encoder and the linear agenda;
    * the training decoder and the beam-search evaluation decoder.

    Only one ``sess.run`` of the beam-search output is performed and
    printed. Nothing is asserted, so the test fails only if graph
    construction or that run raises.
    """
    with tf.Graph().as_default():
        d_fn, _ = dataset_file
        e_fn, _ = embedding_file

        v, embed_matrix = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        stop_token = tf.constant(bytes(vocab.STOP_TOKEN, encoding='utf8'), dtype=tf.string)
        stop_token_id = vocab_lookup.lookup(stop_token)
        start_token = tf.constant(bytes(vocab.START_TOKEN, encoding='utf8'), dtype=tf.string)
        start_token_id = vocab_lookup.lookup(start_token)
        pad_token = tf.constant(bytes(vocab.PAD_TOKEN, encoding='utf8'), dtype=tf.string)
        pad_token_id = vocab_lookup.lookup(pad_token)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE, NUM_EPOCH)
        iterator = dataset.make_initializable_iterator()
        (src, tgt, inw, dlw), _ = iterator.get_next()

        src_len = sequence.length_pre_embedding(src)
        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len, stop_token_id, pad_token_id)
        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        batch_size = tf.shape(src)[0]
        edit_vector = edit_encoder.random_noise_encoder(
            batch_size, EDIT_DIM, 14.0)

        embedding = tf.get_variable(
            'embeddings', shape=embed_matrix.shape,
            initializer=tf.constant_initializer(embed_matrix))

        src_embd = tf.nn.embedding_lookup(embedding, src)
        src_sent_embeds, final_states = encoder.source_sent_encoder(
            src_embd, src_len, 20, 3, 0.8)

        agn = agenda.linear(final_states, edit_vector, 4)

        dec_out = decoder.train_decoder(
            agn, embedding, dec_inputs, src_sent_embeds,
            tf.nn.embedding_lookup(embedding, inw),
            tf.nn.embedding_lookup(embedding, dlw),
            dec_inputs_len, src_len,
            sequence.length_pre_embedding(inw), sequence.length_pre_embedding(dlw),
            5, 20, 3, False)

        eval_dec_out = decoder.beam_eval_decoder(
            agn, embedding,
            start_token_id, stop_token_id,
            src_sent_embeds,
            tf.nn.embedding_lookup(embedding, inw),
            tf.nn.embedding_lookup(embedding, dlw),
            src_len, sequence.length_pre_embedding(inw), sequence.length_pre_embedding(dlw),
            5, 20, 3, 40)

        # Unpacking fails fast if train_decoder stops returning exactly three
        # values. The parts are discarded rather than bound to names that
        # would shadow the builtin `len` or the encoder's `final_states`.
        _, _, _ = dec_out
        _ = decoder.attention_score(dec_out)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer(),
            ])
            sess.run(iterator.initializer)
            print(sess.run([eval_dec_out]))
def test_decoder_prepares(dataset_file, embedding_file):
    """Check decoder input/output preparation on every batch of the dataset.

    For each batch it asserts that:
    * the decoder input and decoder output lengths both equal the target
      length plus one;
    * every decoder input sequence starts with the start-token id;
    * the last relevant decoder output token equals the stop-token id.

    Only end-of-data (``tf.errors.OutOfRangeError``) ends the loop quietly.
    Assertion failures and other errors propagate and fail the test.
    """
    with tf.Graph().as_default():
        d_fn, _ = dataset_file
        e_fn, _ = embedding_file

        v, _ = vocab.read_word_embeddings(e_fn, EMBED_DIM)
        vocab_lookup = vocab.get_vocab_lookup(v)

        stop_token = tf.constant(bytes(vocab.STOP_TOKEN, encoding='utf8'), dtype=tf.string)
        stop_token_id = vocab_lookup.lookup(stop_token)
        start_token = tf.constant(bytes(vocab.START_TOKEN, encoding='utf8'), dtype=tf.string)
        start_token_id = vocab_lookup.lookup(start_token)
        pad_token = tf.constant(bytes(vocab.PAD_TOKEN, encoding='utf8'), dtype=tf.string)
        pad_token_id = vocab_lookup.lookup(pad_token)

        dataset = neural_editor.input_fn(d_fn, vocab_lookup, BATCH_SIZE, NUM_EPOCH)
        iterator = dataset.make_initializable_iterator()
        (_, tgt, _, _), _ = iterator.get_next()
        tgt_len = sequence.length_pre_embedding(tgt)

        dec_inputs = decoder.prepare_decoder_inputs(tgt, start_token_id)
        dec_outputs = decoder.prepare_decoder_output(tgt, tgt_len, stop_token_id, pad_token_id)
        dec_inputs_len = sequence.length_pre_embedding(dec_inputs)
        dec_outputs_len = sequence.length_pre_embedding(dec_outputs)

        # Presumably picks the token at index (length - 1) of each output row;
        # the expand/squeeze pair adapts ids to last_relevant's expected rank.
        dec_outputs_last = sequence.last_relevant(
            tf.expand_dims(dec_outputs, 2), dec_outputs_len)
        dec_outputs_last = tf.squeeze(dec_outputs_last)

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer(),
                tf.tables_initializer(),
            ])
            sess.run(iterator.initializer)

            # Token ids do not depend on the batch, so fetch them once.
            start_id, stop_id = sess.run([start_token_id, stop_token_id])

            while True:
                # Fetched values get their own names. Rebinding the tensor
                # names would break the next sess.run call.
                try:
                    dec_inputs_val, tgt_len_val, dil, dol, last_val = sess.run([
                        dec_inputs, tgt_len, dec_inputs_len,
                        dec_outputs_len, dec_outputs_last,
                    ])
                except tf.errors.OutOfRangeError:
                    # Dataset exhausted: every batch has been checked.
                    break

                assert list(dil) == list(dol) == list(tgt_len_val + 1)
                assert np.all(dec_inputs_val[:, 0] == start_id)
                # np.all also copes with a 0-d result when the batch has size 1.
                assert np.all(np.asarray(last_val) == stop_id)