def decode_test_set(encoder_state, decoder_cell, decoder_embeddings_matrix, sos_id, eos_id,
                    maximum_length, num_words, decoding_scope, output_function, keep_prob,
                    batch_size):
    attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])
    (attention_keys, attention_values, attention_score_function,
     attention_construct_function) = seq2seq.prepare_attention(
        attention_states,
        attention_option="bahdanau",
        num_units=decoder_cell.output_size)
    test_decoder_function = seq2seq.attention_decoder_fn_inference(
        output_function,
        encoder_state[0],
        attention_keys,
        attention_values,
        attention_score_function,
        attention_construct_function,
        decoder_embeddings_matrix,
        sos_id,
        eos_id,
        maximum_length,
        num_words,
        name="attn_dec_inf")
    test_predictions, _, _ = seq2seq.dynamic_rnn_decoder(decoder_cell,
                                                         test_decoder_function,
                                                         scope=decoding_scope)
    return test_predictions
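# --- Usage sketch (not part of the original source) ---
# decode_test_set targets the TF 1.0-era tf.contrib.seq2seq API, where
# prepare_attention / attention_decoder_fn_inference / dynamic_rnn_decoder
# still existed (they were removed in later 1.x releases in favor of
# AttentionWrapper). Every shape, id, and name below is illustrative only.
import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, vocab_size, hidden = 32, 10000, 512
decoder_cell = tf.contrib.rnn.GRUCell(hidden)
decoder_embeddings_matrix = tf.get_variable("dec_emb", [vocab_size, hidden])
encoder_state = (tf.zeros([batch_size, hidden]),)  # stand-in for a real encoder state

with tf.variable_scope("decoding") as decoding_scope:
    output_function = lambda x: tf.contrib.layers.linear(x, vocab_size,
                                                         scope=decoding_scope)
    test_predictions = decode_test_set(encoder_state, decoder_cell,
                                       decoder_embeddings_matrix,
                                       sos_id=1, eos_id=2, maximum_length=25,
                                       num_words=vocab_size,
                                       decoding_scope=decoding_scope,
                                       output_function=output_function,
                                       keep_prob=1.0, batch_size=batch_size)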
def decode_topk(self):
    with self.graph.as_default():
        with tf.variable_scope("Decoder") as scope:
            tf.get_variable_scope().reuse_variables()

            def output_fn(outputs):
                return tf.contrib.layers.linear(outputs, self.vocab_size, scope=scope)

            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_final_state,
                attention_keys=self.attention_keys,
                attention_values=self.attention_values,
                attention_score_fn=self.attention_score_fn,
                attention_construct_fn=self.attention_construct_fn,
                embeddings=self.embed,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=23,  # max_twee_len + 3
                num_decoder_symbols=self.vocab_size)
            (self.decoder_logits_inference_beam,
             self.decoder_state_inference_beam,
             self.decoder_context_state_inference_beam) = (
                 seq2seq.dynamic_rnn_decoder(
                     cell=self.decoder_cell,
                     decoder_fn=decoder_fn_inference,
                     time_major=True,
                     scope=scope))
    return self.decoder_logits_inference_beam, self.decoder_state_inference_beam
def TweetInitDecoder(self, input_state):
    with self.graph.as_default():
        with tf.variable_scope("TweetInitDecoder") as scope:

            def output_fn(outputs):
                return tf.contrib.layers.linear(outputs, self.vocab_size, scope=scope)

            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=input_state,  # self.encoder_final_state
                attention_keys=self.attention_keys,
                attention_values=self.attention_values,
                attention_score_fn=self.attention_score_fn,
                attention_construct_fn=self.attention_construct_fn,
                embeddings=self.embed,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=23,  # max_twee_len + 3
                num_decoder_symbols=self.vocab_size)
            (self.tidecoder_logits_inference,
             self.tidecoder_state_inference,
             self.tidecoder_context_state_inference) = (
                 seq2seq.dynamic_rnn_decoder(
                     cell=self.decoder_cell,
                     decoder_fn=decoder_fn_inference,
                     time_major=True,
                     scope=scope))
            self.tidecoder_prediction_inference = tf.argmax(
                self.tidecoder_logits_inference, axis=-1,
                name='TIdecoder_prediction_inference')
    return self.tidecoder_prediction_inference
def decode_validation_set(encoder_state, decoder_cell, decoder_embeddings_matrix, sos_id,
                          eos_id, max_length, num_words, decoding_scope, output_function,
                          keep_prob, batch_size):
    attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])
    (attention_keys, attention_values, attention_score_fx,
     attention_construct_fx) = prepare_attention(
        attention_states,
        attention_option='bahdanau',
        num_units=decoder_cell.output_size)
    validate_decoder_fx = attention_decoder_fn_inference(
        output_fn=output_function,
        encoder_state=encoder_state[0],
        attention_keys=attention_keys,
        attention_values=attention_values,
        attention_score_fn=attention_score_fx,
        attention_construct_fn=attention_construct_fx,
        embeddings=decoder_embeddings_matrix,
        start_of_sequence_id=sos_id,
        end_of_sequence_id=eos_id,
        maximum_length=max_length,
        num_decoder_symbols=num_words,
        name='attn_dec_inf')
    predictions, _, _ = dynamic_rnn_decoder(cell=decoder_cell,
                                            decoder_fn=validate_decoder_fx,
                                            scope=decoding_scope)
    return predictions
def decoder_test_set(encoder_state, decoder_cell, batch_size, decoder_scope, keep_prob,
                     decoder_embedding_matrix, sequence_length, decoder_output_function,
                     sos_id, eos_id, max_length, num_symbols):
    attention_state = tf.zeros([batch_size, 1, decoder_cell.output_size])
    (attention_key, attention_value, attention_score_function,
     attention_construct_function) = seq2seq.prepare_attention(
        attention_states=attention_state,
        attention_option="bahdanau",
        num_units=decoder_cell.output_size)
    decoder_test_output = seq2seq.attention_decoder_fn_inference(
        output_fn=decoder_output_function,
        encoder_state=encoder_state,
        attention_keys=attention_key,
        attention_values=attention_value,
        attention_score_fn=attention_score_function,
        attention_construct_fn=attention_construct_function,
        embeddings=decoder_embedding_matrix,
        start_of_sequence_id=sos_id,
        end_of_sequence_id=eos_id,
        maximum_length=max_length,
        num_decoder_symbols=num_symbols,
        dtype=tf.float32,
        name="attn_dec_inf")
    decoder_output, _, _ = seq2seq.dynamic_rnn_decoder(decoder_cell,
                                                       decoder_test_output,
                                                       scope=decoder_scope)
    return decoder_output
def decode_test_set(encoder_state, decoder_cell, decoder_embeddings_matrix, sos_id, eos_id,
                    maximum_length, num_words, sequence_length, decoding_scope,
                    output_function, keep_prob, batch_size):
    """Decode the test set."""
    attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])
    (attention_keys, attention_values, attention_score_function,
     attention_construct_function) = prepare_attention(
        attention_states,
        attention_option='bahdanau',
        num_units=decoder_cell.output_size)
    test_decoder_function = attention_decoder_fn_inference(
        output_function,
        encoder_state[0],
        attention_keys,
        attention_values,
        attention_score_function,
        attention_construct_function,
        decoder_embeddings_matrix,
        sos_id,
        eos_id,
        maximum_length,
        num_words,
        name='attn_dec_inf')
    test_predictions, _, _ = dynamic_rnn_decoder(decoder_cell,
                                                 test_decoder_function,
                                                 scope=decoding_scope)
    return test_predictions
def addDecoder(self):
    print('adding decoder...')
    cell = BasicRNNCell(2 * CONFIG.DIM_WordEmbedding)
    self.attention_states = self._encoder_outputs
    self.decoder_inputs_embedded = tf.nn.embedding_lookup(self.embedding,
                                                          self.y_placeholder)
    # prepare attention:
    (attention_keys, attention_values, attention_score_fn,
     attention_construct_fn) = seq2seq.prepare_attention(
        attention_states=self.attention_states,
        attention_option='bahdanau',
        num_units=2 * CONFIG.DIM_WordEmbedding)
    if self.is_training:
        # new Seq2seq train version
        self.check_op = tf.add_check_numerics_ops()
        decoder_fn_train = seq2seq.attention_decoder_fn_train(
            encoder_state=self._decoder_in_state,
            attention_keys=attention_keys,
            attention_values=attention_values,
            attention_score_fn=attention_score_fn,
            attention_construct_fn=attention_construct_fn,
            name='attention_decoder')
        (self.decoder_outputs_train, self.decoder_state_train,
         self.decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(
            cell=cell,
            decoder_fn=decoder_fn_train,
            inputs=self.decoder_inputs_embedded,
            sequence_length=self.y_lens,
            time_major=False)
        self.decoder_outputs = self.decoder_outputs_train
    else:
        # new Seq2seq inference version
        start_id = CONFIG.WORDS[CONFIG.STARTWORD]
        stop_id = CONFIG.WORDS[CONFIG.STOPWORD]
        decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
            encoder_state=self._decoder_in_state,
            attention_keys=attention_keys,
            attention_values=attention_values,
            attention_score_fn=attention_score_fn,
            attention_construct_fn=attention_construct_fn,
            embeddings=self.embedding,
            start_of_sequence_id=start_id,
            end_of_sequence_id=stop_id,
            maximum_length=CONFIG.DIM_DECODER,
            num_decoder_symbols=CONFIG.DIM_VOCAB,
            output_fn=self.output_fn)
        (self.decoder_outputs_inference, self.decoder_state_inference,
         self.decoder_context_state_inference) = seq2seq.dynamic_rnn_decoder(
            cell=cell,
            decoder_fn=decoder_fn_inference,
            time_major=False)
        self.decoder_outputs = self.decoder_outputs_inference
def decoder(self, encoder_state, attention_states, inputs=None, is_train=True):
    '''Attention-based decoder.

    1. Call seq2seq.prepare_attention to build the attention keys/values/functions.
    2. For training, build the attention_decoder_fn_train used by dynamic_rnn_decoder.
    3. For inference, build the attention_decoder_fn_inference used by dynamic_rnn_decoder.
    4. Pass the results from the steps above to seq2seq.dynamic_rnn_decoder.
    '''
    with tf.variable_scope("decoder") as scope:
        # 1. prepare attention
        keys, values, score_fn, construct_fn = seq2seq.prepare_attention(
            attention_states=attention_states,
            attention_option="luong",
            num_units=self.emb_dim)
        if is_train:
            decoder_fn = seq2seq.attention_decoder_fn_train(
                encoder_state,
                attention_keys=keys,
                attention_values=values,
                attention_score_fn=score_fn,
                attention_construct_fn=construct_fn)
            outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
                self.decoder_cell,
                decoder_fn=decoder_fn,
                inputs=inputs,
                sequence_length=self.seq_len,
                time_major=False,
                scope=scope)
        else:
            tf.get_variable_scope().reuse_variables()
            # At decoding time, compute each word's probability from the decoder
            # embedding matrix and the decoder bias.
            output_fn = lambda x: tf.nn.softmax(
                tf.matmul(x, self.dec_embedding, transpose_b=True) + self.dec_bias)
            decoder_fn = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=encoder_state,
                attention_keys=keys,
                attention_values=values,
                attention_score_fn=score_fn,
                attention_construct_fn=construct_fn,
                embeddings=self.dec_embedding,
                start_of_sequence_id=0,
                end_of_sequence_id=1,
                maximum_length=5,
                num_decoder_symbols=self.vocab_size)
            outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
                self.decoder_cell,
                decoder_fn=decoder_fn,
                inputs=None,
                sequence_length=self.seq_len,
                time_major=False,
                scope=scope)
    return outputs, final_state, final_context_state
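# --- Usage sketch (not part of the original source) ---
# The decoder above is meant to be built twice over the same "decoder" variable
# scope: once with is_train=True (teacher forcing, embedded targets fed as
# inputs) and once with is_train=False (greedy feedback through output_fn; the
# reuse_variables() call in the else branch shares the training weights).
# `model`, `encoder_state`, `attention_states`, and `decoder_inputs_embedded`
# are illustrative names, not taken from the original code.
train_outputs, _, _ = model.decoder(encoder_state, attention_states,
                                    inputs=decoder_inputs_embedded, is_train=True)
infer_outputs, _, _ = model.decoder(encoder_state, attention_states,
                                    is_train=False)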
def _init_decoder(self):
    with tf.variable_scope("Decoder") as scope:

        def output_fn(outputs):
            self.test_outputs = outputs
            return tf.contrib.layers.linear(outputs, self.decoder_vocab_size, scope=scope)

        if not self.attention:
            decoder_fn_train = seq2seq.simple_decoder_fn_train(
                encoder_state=self.encoder_state)
            decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                embeddings=self.decoder_embedding_matrix,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=tf.reduce_max(self.encoder_inputs_length) + 100,
                num_decoder_symbols=self.decoder_vocab_size,
            )
        else:
            # attention_states: size [batch_size, max_time, num_units]
            attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2])
            (attention_keys, attention_values, attention_score_fn,
             attention_construct_fn) = seq2seq.prepare_attention(
                attention_states=attention_states,
                attention_option="bahdanau",
                num_units=self.decoder_hidden_units,
            )
            decoder_fn_train = seq2seq.attention_decoder_fn_train(
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                name='attention_decoder')
            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                embeddings=self.decoder_embedding_matrix,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=tf.reduce_max(self.encoder_inputs_length) + 100,
                num_decoder_symbols=self.decoder_vocab_size,
            )
        (self.decoder_outputs_train, self.decoder_state_train,
         self.decoder_context_state_train) = (
            seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_train,
                inputs=self.decoder_train_inputs_embedded,
                sequence_length=self.decoder_train_length,
                time_major=self.time_major,
                scope=scope,
            ))
        self.decoder_logits_train = output_fn(self.decoder_outputs_train)
        self.decoder_prediction_train = tf.argmax(self.decoder_logits_train, axis=-1,
                                                  name='decoder_prediction_train')
        scope.reuse_variables()
        (self.decoder_logits_inference, self.decoder_state_inference,
         self.decoder_context_state_inference) = (
            seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_inference,
                time_major=self.time_major,
                scope=scope,
            ))
        self.decoder_prediction_inference = tf.argmax(
            self.decoder_logits_inference, axis=-1,
            name='decoder_prediction_inference')
def __init_decoder(self):
    '''Initializes the decoder part of the model.'''
    with tf.variable_scope('decoder') as scope:
        output_fn = lambda outs: layers.linear(outs, self.__get_vocab_size(), scope=scope)
        if self.cfg.get('use_attention'):
            attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2])
            (attention_keys, attention_values, attention_score_fn,
             attention_construct_fn) = seq2seq.prepare_attention(
                attention_states=attention_states,
                attention_option='bahdanau',
                num_units=self.decoder_cell.output_size)
            decoder_fn_train = seq2seq.attention_decoder_fn_train(
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                name='attention_decoder')
            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                embeddings=self.embeddings,
                start_of_sequence_id=Config.EOS_WORD_IDX,
                end_of_sequence_id=Config.EOS_WORD_IDX,
                maximum_length=tf.reduce_max(self.encoder_inputs_length) + 3,
                num_decoder_symbols=self.__get_vocab_size())
        else:
            decoder_fn_train = seq2seq.simple_decoder_fn_train(
                encoder_state=self.encoder_state)
            decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                embeddings=self.embeddings,
                start_of_sequence_id=Config.EOS_WORD_IDX,
                end_of_sequence_id=Config.EOS_WORD_IDX,
                maximum_length=tf.reduce_max(self.encoder_inputs_length) + 3,
                num_decoder_symbols=self.__get_vocab_size())
        (self.decoder_outputs_train, self.decoder_state_train,
         self.decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(
            cell=self.decoder_cell,
            decoder_fn=decoder_fn_train,
            inputs=self.decoder_train_inputs_embedded,
            sequence_length=self.decoder_train_length,
            time_major=True,
            scope=scope)
        self.decoder_logits_train = output_fn(self.decoder_outputs_train)
        self.decoder_prediction_train = tf.argmax(
            self.decoder_logits_train, axis=-1, name='decoder_prediction_train')
        scope.reuse_variables()
        (self.decoder_logits_inference, decoder_state_inference,
         self.decoder_context_state_inference) = seq2seq.dynamic_rnn_decoder(
            cell=self.decoder_cell,
            decoder_fn=decoder_fn_inference,
            time_major=True,
            scope=scope)
        self.decoder_prediction_inference = tf.argmax(
            self.decoder_logits_inference, axis=-1, name='decoder_prediction_inference')
def __init__(self, para):
    para.fac = int(para.bidirectional) + 1
    self._para = para
    if para.rnn_type == 0:  # basic RNN
        def unit_cell(fac):
            return tf.contrib.rnn.BasicRNNCell(para.hidden_size * fac)
    elif para.rnn_type == 1:  # basic LSTM
        def unit_cell(fac):
            return tf.contrib.rnn.BasicLSTMCell(para.hidden_size * fac)
    elif para.rnn_type == 2:  # full LSTM
        def unit_cell(fac):
            return tf.contrib.rnn.LSTMCell(para.hidden_size * fac, use_peepholes=True)
    elif para.rnn_type == 3:  # GRU
        def unit_cell(fac):
            return tf.contrib.rnn.GRUCell(para.hidden_size * fac)
    rnn_cell = unit_cell

    # dropout layer
    if not self.is_test() and para.keep_prob < 1:
        def rnn_cell(fac):
            return tf.contrib.rnn.DropoutWrapper(unit_cell(fac),
                                                 output_keep_prob=para.keep_prob)

    # multi-layer RNN
    encoder_cell = tf.contrib.rnn.MultiRNNCell(
        [rnn_cell(1) for _ in range(para.layer_num)])
    if para.bidirectional:
        b_encoder_cell = tf.contrib.rnn.MultiRNNCell(
            [rnn_cell(1) for _ in range(para.layer_num)])

    # feed in data in batches
    if not self.is_test():
        video, caption, v_len, c_len = self.get_single_example(para)
        videos, captions, v_lens, c_lens = tf.train.batch(
            [video, caption, v_len, c_len],
            batch_size=para.batch_size, dynamic_pad=True)
        # sparse tensor cannot be sliced
        targets = tf.sparse_tensor_to_dense(captions)
        decoder_in = targets[:, :-1]
        decoder_out = targets[:, 1:]
        c_lens = tf.to_int32(c_lens)
    else:
        video, v_len = self.get_single_example(para)
        videos, v_lens = tf.train.batch([video, v_len],
                                        batch_size=para.batch_size,
                                        dynamic_pad=True)
    v_lens = tf.to_int32(v_lens)

    with tf.variable_scope('embedding'):
        if para.use_pretrained:
            W_E = tf.Variable(tf.constant(0., shape=[para.vocab_size, para.w_emb_dim]),
                              trainable=False, name='W_E')
            self._embedding = tf.placeholder(tf.float32,
                                             [para.vocab_size, para.w_emb_dim])
            self._embed_init = W_E.assign(self._embedding)
        else:
            W_E = tf.get_variable('W_E', [para.vocab_size, para.w_emb_dim],
                                  dtype=tf.float32)
        if not self.is_test():
            decoder_in_embed = tf.nn.embedding_lookup(W_E, decoder_in)

    if para.v_emb_dim < para.video_dim:
        inputs = fully_connected(videos, para.v_emb_dim)
    else:
        inputs = videos
    if not self.is_test() and para.keep_prob < 1:
        inputs = tf.nn.dropout(inputs, para.keep_prob)

    if not para.bidirectional:
        encoder_outputs, encoder_states = tf.nn.dynamic_rnn(
            encoder_cell, inputs, sequence_length=v_lens, dtype=tf.float32)
    else:
        encoder_outputs, encoder_states = tf.nn.bidirectional_dynamic_rnn(
            encoder_cell, b_encoder_cell, inputs,
            sequence_length=v_lens, dtype=tf.float32)
        # concatenate forward and backward states per layer
        encoder_states = tuple([
            LSTMStateTuple(tf.concat([f_st.c, b_st.c], 1),
                           tf.concat([f_st.h, b_st.h], 1))
            for f_st, b_st in zip(encoder_states[0], encoder_states[1])
        ])
        encoder_outputs = tf.concat([encoder_outputs[0], encoder_outputs[1]], 2)

    with tf.variable_scope('softmax'):
        softmax_w = tf.get_variable('w', [para.hidden_size * para.fac, para.vocab_size],
                                    dtype=tf.float32)
        softmax_b = tf.get_variable('b', [para.vocab_size], dtype=tf.float32)
    output_fn = lambda output: tf.nn.xw_plus_b(output, softmax_w, softmax_b)

    decoder_cell = tf.contrib.rnn.MultiRNNCell(
        [rnn_cell(para.fac) for _ in range(para.layer_num)])
    if para.attention > 0:
        at_option = ["bahdanau", "luong"][para.attention - 1]
        at_keys, at_vals, at_score, at_cons = seq2seq.prepare_attention(
            attention_states=encoder_outputs,
            attention_option=at_option,
            num_units=para.hidden_size * para.fac)

    if self.is_test():
        if para.attention:
            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=encoder_states,
                attention_keys=at_keys,
                attention_values=at_vals,
                attention_score_fn=at_score,
                attention_construct_fn=at_cons,
                embeddings=W_E,
                start_of_sequence_id=2,
                end_of_sequence_id=3,
                maximum_length=20,
                num_decoder_symbols=para.vocab_size)
        else:
            decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=encoder_states,
                embeddings=W_E,
                start_of_sequence_id=2,
                end_of_sequence_id=3,
                maximum_length=20,
                num_decoder_symbols=para.vocab_size)
        with tf.variable_scope('decode', reuse=True):
            decoder_logits, _, _ = seq2seq.dynamic_rnn_decoder(
                cell=decoder_cell, decoder_fn=decoder_fn_inference)
        self._prob = tf.nn.softmax(decoder_logits)
    else:
        global_step = tf.contrib.framework.get_or_create_global_step()

        def decoder_fn_train(time, cell_state, cell_input, cell_output, context):
            # scheduled sampling: with probability epsilon keep the ground-truth
            # input, otherwise feed back the embedding of the model's own argmax
            if para.scheduled_sampling and cell_output is not None:
                epsilon = tf.cast(
                    1 - (global_step //
                         (para.tot_train_num // para.batch_size + 1) / para.max_epoch),
                    tf.float32)
                cell_input = tf.cond(
                    tf.less(tf.random_uniform([1]), epsilon)[0],
                    lambda: cell_input,
                    lambda: tf.gather(W_E, tf.argmax(output_fn(cell_output), 1)))
            if cell_state is None:  # first call
                cell_state = encoder_states
                if para.attention:
                    # _init_attention: zero-state helper mirroring the one in
                    # tf.contrib.seq2seq's attention decoder internals
                    attention = _init_attention(encoder_states)
            else:
                if para.attention:
                    cell_output = attention = at_cons(cell_output, at_keys, at_vals)
            if para.attention:
                nxt_cell_input = tf.concat([cell_input, attention], 1)
            else:
                nxt_cell_input = cell_input
            return None, cell_state, nxt_cell_input, cell_output, context

        with tf.variable_scope('decode', reuse=None):
            (decoder_outputs, _, _) = seq2seq.dynamic_rnn_decoder(
                cell=decoder_cell,
                decoder_fn=decoder_fn_train,
                inputs=decoder_in_embed,
                sequence_length=c_lens)
        decoder_outputs = tf.reshape(decoder_outputs,
                                     [-1, para.hidden_size * para.fac])
        c_len_max = tf.reduce_max(c_lens)
        logits = output_fn(decoder_outputs)
        logits = tf.reshape(logits, [para.batch_size, c_len_max, para.vocab_size])
        self._prob = tf.nn.softmax(logits)
        msk = tf.sequence_mask(c_lens, dtype=tf.float32)
        loss = sequence_loss(logits, decoder_out, msk)
        self._cost = cost = tf.reduce_mean(loss)
        # if validation or testing, exit here
        if self.is_valid():
            return
        # clip global gradient norm
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
                                          para.max_grad_norm)
        optimizer = optimizers[para.optimizer](para.learning_rate)
        self._eval = optimizer.apply_gradients(zip(grads, tvars),
                                               global_step=global_step)
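# --- Scheduled-sampling sketch (not part of the original source) ---
# decoder_fn_train above flips a coin each step: with probability epsilon
# (decaying toward 0 over training) it keeps the ground-truth input, otherwise
# it feeds back the embedding of the model's own argmax prediction. A minimal
# standalone version of that choice; all names, shapes, and the schedule length
# are illustrative, not taken from the original code.
import tensorflow as tf

vocab_size, emb_dim, hidden = 1000, 64, 128
W_E = tf.get_variable('W_E_demo', [vocab_size, emb_dim])
cell_input = tf.zeros([8, emb_dim])   # ground-truth embedding at this step
cell_output = tf.zeros([8, hidden])   # previous decoder output
output_fn = lambda x: tf.contrib.layers.linear(x, vocab_size)
global_step = tf.contrib.framework.get_or_create_global_step()
total_steps = 10000.0                 # assumed schedule length

# linearly decaying probability of teacher forcing, clipped at 0
epsilon = tf.maximum(0.0, 1.0 - tf.to_float(global_step) / total_steps)
use_truth = tf.less(tf.random_uniform([]), epsilon)
next_input = tf.cond(
    use_truth,
    lambda: cell_input,                                            # teacher forcing
    lambda: tf.gather(W_E, tf.argmax(output_fn(cell_output), 1)))  # model feedback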
def _init_decoder(self):
    with tf.variable_scope("Decoder") as scope:

        def output_fn(outputs):
            return tf.contrib.layers.linear(outputs, self.vocab_size, scope=scope)

        if not self.attention:
            decoder_fn_train = seq2seq.simple_decoder_fn_train(
                encoder_state=self.encoder_state)
            decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                embeddings=self.embedding_matrix,
                start_of_sequence_id=data_utils.GO_ID,
                end_of_sequence_id=data_utils.EOS_ID,
                maximum_length=FLAGS.max_inf_target_len,
                num_decoder_symbols=self.vocab_size,
            )
        else:
            # attention_states: size [batch_size, max_time, num_units]
            attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2])
            # attention_states = tf.zeros([batch_size, 1, self.decoder_hidden_units])
            (attention_keys, attention_values, attention_score_fn,
             attention_construct_fn) = seq2seq.prepare_attention(
                attention_states=attention_states,
                attention_option="bahdanau",
                num_units=self.decoder_hidden_units,
            )
            decoder_fn_train = seq2seq.attention_decoder_fn_train(
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                name='attention_decoder')
            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                embeddings=self.embedding_matrix,
                start_of_sequence_id=data_utils.GO_ID,
                end_of_sequence_id=data_utils.EOS_ID,
                maximum_length=FLAGS.max_inf_target_len,
                num_decoder_symbols=self.vocab_size,
            )
        (self.decoder_outputs_train, self.decoder_state_train,
         self.decoder_context_state_train) = (seq2seq.dynamic_rnn_decoder(
            cell=self.decoder_cell,
            decoder_fn=decoder_fn_train,
            inputs=self.decoder_train_inputs_embedded,
            sequence_length=self.decoder_train_length,
            time_major=True,
            scope=scope,
        ))
        self.decoder_outputs_train = tf.nn.dropout(self.decoder_outputs_train,
                                                   _keep_prob)
        self.decoder_logits_train = output_fn(self.decoder_outputs_train)
        # reuse the training scope so inference uses the same variables
        scope.reuse_variables()
        (self.decoder_logits_inference, self.decoder_state_inference,
         self.decoder_context_state_inference) = (
            seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_inference,
                time_major=True,
                scope=scope,
            ))
        self.decoder_prediction_inference = tf.argmax(
            self.decoder_logits_inference, axis=-1,
            name='decoder_prediction_inference')
def decoder_adv(self, max_twee_len):
    with self.graph.as_default():
        with tf.variable_scope("Decoder") as scope:
            self.decoder_length = max_twee_len + 3

            def output_fn(outputs):
                return tf.contrib.layers.linear(outputs, self.vocab_size, scope=scope)

            # self.decoder_cell = LSTMCell(self.decoder_hidden_nodes)
            self.decoder_cell = GRUCell(self.decoder_hidden_nodes)
            if not self.attention:
                decoder_fn_train = seq2seq.simple_decoder_fn_train(
                    encoder_state=self.encoder_final_state)
                decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_final_state,
                    embeddings=self.embed,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=self.decoder_length,
                    num_decoder_symbols=self.vocab_size)
            else:
                # attention_states: size [batch_size, max_time, num_units]
                self.attention_states = tf.transpose(self.encoder_output, [1, 0, 2])
                (self.attention_keys, self.attention_values, self.attention_score_fn,
                 self.attention_construct_fn) = seq2seq.prepare_attention(
                    attention_states=self.attention_states,
                    attention_option="bahdanau",
                    num_units=self.decoder_hidden_nodes)
                decoder_fn_train = seq2seq.attention_decoder_fn_train(
                    encoder_state=self.encoder_final_state,
                    attention_keys=self.attention_keys,
                    attention_values=self.attention_values,
                    attention_score_fn=self.attention_score_fn,
                    attention_construct_fn=self.attention_construct_fn,
                    name="attention_decoder")
                decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_final_state,
                    attention_keys=self.attention_keys,
                    attention_values=self.attention_values,
                    attention_score_fn=self.attention_score_fn,
                    attention_construct_fn=self.attention_construct_fn,
                    embeddings=self.embed,
                    start_of_sequence_id=self.EOS,
                    end_of_sequence_id=self.EOS,
                    maximum_length=23,  # max_twee_len + 3, or tf.reduce_max(self.de_out_len) + 3
                    num_decoder_symbols=self.vocab_size)
            self.decoder_train_inputs_embedded = tf.nn.embedding_lookup(
                self.embed, self.decoder_train_input)
            (self.decoder_outputs_train, self.decoder_state_train,
             self.decoder_context_state_train) = (
                seq2seq.dynamic_rnn_decoder(
                    cell=self.decoder_cell,
                    decoder_fn=decoder_fn_train,
                    inputs=self.decoder_train_inputs_embedded,
                    sequence_length=self.decoder_train_length,
                    time_major=True,
                    scope=scope))
            self.decoder_logits_train = output_fn(self.decoder_outputs_train)
            self.decoder_prediction_train = tf.argmax(
                self.decoder_logits_train, axis=-1, name='decoder_prediction_train')
            scope.reuse_variables()
            (self.decoder_logits_inference, self.decoder_state_inference,
             self.decoder_context_state_inference) = (
                seq2seq.dynamic_rnn_decoder(
                    cell=self.decoder_cell,
                    decoder_fn=decoder_fn_inference,
                    time_major=True,
                    scope=scope))
            self.decoder_prediction_inference = tf.argmax(
                self.decoder_logits_inference, axis=-1,
                name='decoder_prediction_inference')
    return (self.de_out, self.de_out_len, self.title_out, self.first_out,
            self.decoder_logits_train, self.decoder_prediction_train,
            self.loss_weights, self.decoder_train_targets,
            self.decoder_train_title, self.decoder_train_first,
            self.decoder_prediction_inference)
def _init_decoder(self, output_projection):
    '''Decoder phase.'''
    with tf.variable_scope("Decoder") as scope:
        self.decoder_inputs_embedded = tf.nn.embedding_lookup(self.lookup_matrix,
                                                              self.decoder_inputs)
        (attention_keys, attention_values, attention_score_fn,
         attention_construct_fn) = seq2seq.prepare_attention(
            attention_states=self.attention_states,
            attention_option="bahdanau",
            num_units=self.args.h_units_decoder,
        )
        # attention is added
        decoder_fn_train = seq2seq.attention_decoder_fn_train(
            encoder_state=self.encoder_state,
            attention_keys=attention_keys,
            attention_values=attention_values,
            attention_score_fn=attention_score_fn,
            attention_construct_fn=attention_construct_fn,
            name='attention_decoder'
        )
        decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
            output_fn=output_projection,
            encoder_state=self.encoder_state,
            attention_keys=attention_keys,
            attention_values=attention_values,
            attention_score_fn=attention_score_fn,
            attention_construct_fn=attention_construct_fn,
            embeddings=self.lookup_matrix,
            start_of_sequence_id=self.textData.goToken,
            end_of_sequence_id=self.textData.eosToken,
            maximum_length=tf.reduce_max(self.decoder_targets_length),
            num_decoder_symbols=self.textData.getVocabularySize(),
        )
        # TODO: check back later -- does the hidden size of decoder_cell have to
        # match the embedding size?
        # decoder_outputs_train.shape = (batch_size, n_words, hidden_size)
        (self.decoder_outputs_train, decoder_state_train,
         decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(
            cell=self.decoder_cell,
            decoder_fn=decoder_fn_train,
            inputs=self.decoder_inputs_embedded,
            sequence_length=self.decoder_targets_length,
            time_major=False,
            scope=scope
        )
        # self.decoder_logits_train = output_projection(self.decoder_outputs_train)
        # self.decoder_logits_train_trans = tf.reshape(self.decoder_outputs_train, [1, 0, 2])
        # apply the output projection per time step via map_fn over the
        # time-major transpose, then transpose back to batch-major
        self.decoder_logits_train = tf.transpose(
            tf.map_fn(output_projection,
                      tf.transpose(self.decoder_outputs_train, [1, 0, 2])),
            [1, 0, 2])
        self.decoder_prediction_train = tf.argmax(self.decoder_logits_train, axis=-1,
                                                  name='decoder_prediction_train')
        # reuse the same variables for both training and inference
        scope.reuse_variables()
        (self.decoder_logits_inference, decoder_state_inference,
         decoder_context_state_inference) = (
            seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_inference,
                time_major=False,
                scope=scope
            ))
def _init_decoder(self):
    with tf.variable_scope("Decoder") as scope:

        def output_fn(outputs):
            # project decoder outputs to vocabulary logits (greedy decoding head)
            return tf.contrib.layers.linear(outputs, self.vocab_size, scope=scope)

        if not self.attention:
            # training function used by dynamic_rnn_decoder; see
            # https://github.com/tensorflow/tensorflow/blob/r1.0/tensorflow/contrib/seq2seq/python/ops/decoder_fn.py#L182
            decoder_fn_train = seq2seq.simple_decoder_fn_train(
                encoder_state=self.encoder_state)
            # inference function for a sequence-to-sequence model; used when
            # dynamic_rnn_decoder runs in inference mode
            decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                output_fn=output_fn,  # returns a decoder function used inside dynamic_rnn_decoder
                encoder_state=self.encoder_state,
                embeddings=self.embedding_matrix,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=tf.reduce_max(self.encoder_inputs_length) + 3,
                num_decoder_symbols=self.vocab_size,
            )
        else:
            # attention_states: size [batch_size, max_time, num_units];
            # the encoder hidden states serve as the attention states
            attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2])
            (attention_keys,         # each encoder hidden state projected through a fully connected layer, size [num_units, max_time]
             attention_values,       # the attended encoder states
             attention_score_fn,     # attention scoring: given the decoder state and encoder hidden states, outputs the context vector
             attention_construct_fn  # builds attention vectors: concatenates the context vector and the attention query, projects the result, and feeds it back as input
             ) = seq2seq.prepare_attention(
                attention_states=attention_states,
                attention_option="bahdanau",
                num_units=self.decoder_hidden_units,
            )
            print("Printing the number of units:")
            print(self.decoder_hidden_units)
            print("Printing the attention keys:")
            print(attention_keys)
            print("Printing the attention score function:")
            print(attention_score_fn)
            # attention_decoder_fn_train initializes the decoder input state and
            # the attention, which are then passed to dynamic_rnn_decoder; the
            # decoder function receives (time, cell_state, cell_input,
            # cell_output, context_state)
            decoder_fn_train = seq2seq.attention_decoder_fn_train(
                encoder_state=self.encoder_state,  # final state; for a bidirectional encoder, the concatenated c or h
                attention_keys=attention_keys,     # the transformation of each encoder output
                attention_values=attention_values, # the attended encoder states
                attention_score_fn=attention_score_fn,          # yields a context vector
                attention_construct_fn=attention_construct_fn,  # computes the above and also outputs the hidden state
                name='attention_decoder')
            # Running the decoder function yields: done, next state, next input,
            # emitted output, next context state. The emitted cell_output is the
            # cell output after attention and the non-linearity are applied; the
            # concatenation of the RNN output and the attention vector goes
            # through a linear unit. next_input = array_ops.concat([cell_input,
            # attention], 1). context_state is modified when using beam search.
            #
            # The inference function below is the same except that it is used at
            # inference time and decodes greedily: cell_output =
            # output_fn(cell_output) gives logits, and next_input =
            # array_ops.concat([cell_input, attention], 1).
            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,  # predicts the output; after the argmax, the attention is concatenated
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,        # same attention as in training
                attention_construct_fn=attention_construct_fn,
                embeddings=self.embedding_matrix,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=tf.reduce_max(self.encoder_inputs_length) + 3,
                num_decoder_symbols=self.vocab_size,
            )
        # dynamic_rnn_decoder performs the full decoding with the functions
        # above; it serves both training and inference, but needs a separate
        # decoder_fn for each. On decoder_context_state_train: one way to
        # diversify the inference output is to use a stochastic decoder_fn, in
        # which case one would want to store the decoded outputs, not just the
        # RNN outputs; this can be done by maintaining a TensorArray in
        # context_state and storing the decoded output of each iteration therein.
        (self.decoder_outputs_train,  # outputs from each cell: [batch_size, max_time, cell.output_size]
         self.decoder_state_train,    # final state, shaped [batch_size, cell.state_size]
         self.decoder_context_state_train) = (
            seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_train,  # models early stopping, output, state, next input, and context
                inputs=self.decoder_train_inputs_embedded,  # decoder inputs, training time only
                sequence_length=self.decoder_train_length,  # needed at training time (inputs is not None) for dynamic unrolling; not needed at test time when inputs is None
                time_major=True,  # inputs and outputs shaped [max_time, batch_size, ...]
                scope=scope,
            ))
        # run the final hidden states through the linear layer, then take the argmax
        self.decoder_logits_train = output_fn(self.decoder_outputs_train)
        self.decoder_prediction_train = tf.argmax(
            self.decoder_logits_train, axis=-1, name='decoder_prediction_train')
        scope.reuse_variables()
        (self.decoder_logits_inference,  # as above, but with no inputs provided: predictions are fed back as inputs
         self.decoder_state_inference,
         self.decoder_context_state_inference) = (
            seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_inference,  # the inference decoder function
                time_major=True,
                scope=scope,
            ))
        self.decoder_prediction_inference = tf.argmax(
            self.decoder_logits_inference, axis=-1,
            name='decoder_prediction_inference')  # predicted token at each time step
def _init_decoder(self):
    with tf.variable_scope("Decoder") as scope:

        def output_fn(outputs):
            return tc.layers.fully_connected(outputs, self.output_symbol_size,
                                             activation_fn=None, scope=scope)

        if not self.attention:
            decoder_fn_train = seq2seq.simple_decoder_fn_train(
                encoder_state=self.encoder_state)
            decoder_fn_inference = seq2seq.simple_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                embeddings=self.embedding_matrix,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=tf.reduce_max(self.encoder_inputs_length),
                num_decoder_symbols=self.output_symbol_size)
        else:
            (attention_keys, attention_values, attention_score_fn,
             attention_construct_fn) = seq2seq.prepare_attention(
                attention_states=self.encoder_outputs,
                attention_option="bahdanau",
                num_units=self.decoder_hidden_units,
            )
            decoder_fn_train = seq2seq.attention_decoder_fn_train(
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                name='attention_decoder')
            decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
                output_fn=output_fn,
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                embeddings=self.embedding_matrix,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=tf.reduce_max(self.encoder_inputs_length),
                num_decoder_symbols=self.output_symbol_size,
            )
        if self.is_training:
            (self.decoder_outputs_train, self.decoder_state_train,
             self.decoder_context_state_train) = (
                seq2seq.dynamic_rnn_decoder(
                    cell=self.decoder_cell,
                    decoder_fn=decoder_fn_train,
                    inputs=self.decoder_train_inputs_embedded,
                    sequence_length=self.decoder_train_length,
                    time_major=False,
                    scope=scope,
                ))
            self.decoder_logits_train = output_fn(self.decoder_outputs_train)
            self.decoder_prediction_train = tf.argmax(
                self.decoder_logits_train, axis=-1, name='decoder_prediction_train')
            scope.reuse_variables()
        (self.decoder_logits_inference, self.decoder_state_inference,
         self.decoder_context_state_inference) = (
            seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_inference,
                time_major=False,
                scope=scope,
            ))
        self.decoder_prediction_inference = tf.argmax(
            self.decoder_logits_inference, axis=-1,
            name='decoder_prediction_inference')
def _init_decoder(self, forward_only):
    with tf.variable_scope("decoder") as scope:

        def output_fn(outputs):
            return tf.contrib.layers.linear(outputs, self.target_vocab_size, scope=scope)

        self.attention = True  # attention is forced on; the non-attention branch below is kept but unreachable
        if not self.attention:
            if forward_only:
                decoder_fn = seq2seq.simple_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    embeddings=self.dec_embedding_matrix,
                    start_of_sequence_id=model_config.GO_ID,
                    end_of_sequence_id=model_config.EOS_ID,
                    maximum_length=self.buckets[-1][1],
                    num_decoder_symbols=self.target_vocab_size,
                )
                (self.decoder_outputs, self.decoder_state,
                 self.decoder_context_state) = (
                    seq2seq.dynamic_rnn_decoder(
                        cell=self.decoder_cell,
                        decoder_fn=decoder_fn,
                        time_major=True,
                        scope=scope,
                    ))
            else:
                decoder_fn = seq2seq.simple_decoder_fn_train(
                    encoder_state=self.encoder_state)
                (self.decoder_outputs, self.decoder_state,
                 self.decoder_context_state) = (
                    seq2seq.dynamic_rnn_decoder(
                        cell=self.decoder_cell,
                        decoder_fn=decoder_fn,
                        inputs=self.decoder_inputs_embedded,
                        sequence_length=self.decoder_inputs_length,
                        time_major=True,
                        scope=scope,
                    ))
        else:
            # attention_states: size [batch_size, max_time, num_units]
            attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2])
            (attention_keys, attention_values, attention_score_fn,
             attention_construct_fn) = seq2seq.prepare_attention(
                attention_states=attention_states,
                attention_option="bahdanau",
                num_units=self.dec_hidden_size)
            if forward_only:
                decoder_fn = seq2seq.attention_decoder_fn_inference(
                    output_fn=output_fn,
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    embeddings=self.dec_embedding_matrix,
                    start_of_sequence_id=model_config.GO_ID,
                    end_of_sequence_id=model_config.EOS_ID,
                    maximum_length=self.buckets[-1][1],
                    num_decoder_symbols=self.target_vocab_size,
                )
                (self.decoder_outputs, self.decoder_state,
                 self.decoder_context_state) = (
                    seq2seq.dynamic_rnn_decoder(
                        cell=self.decoder_cell,
                        decoder_fn=decoder_fn,
                        time_major=True,
                        scope=scope,
                    ))
            else:
                decoder_fn = seq2seq.attention_decoder_fn_train(
                    encoder_state=self.encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_score_fn=attention_score_fn,
                    attention_construct_fn=attention_construct_fn,
                    name='attention_decoder')
                (self.decoder_outputs, self.decoder_state,
                 self.decoder_context_state) = (
                    seq2seq.dynamic_rnn_decoder(
                        cell=self.decoder_cell,
                        decoder_fn=decoder_fn,
                        inputs=self.decoder_inputs_embedded,
                        sequence_length=self.decoder_inputs_length,
                        time_major=True,
                        scope=scope,
                    ))
        if not forward_only:
            # in training mode dynamic_rnn_decoder emits raw cell outputs, so
            # project them to logits; in inference mode output_fn has already
            # been applied inside the decoder function
            self.decoder_logits = output_fn(self.decoder_outputs)
        else:
            self.decoder_logits = self.decoder_outputs
        self.decoder_prediction = tf.argmax(self.decoder_logits, axis=-1,
                                            name='decoder_prediction')
        logits = tf.transpose(self.decoder_logits, [1, 0, 2])
        targets = tf.transpose(self.decoder_targets, [1, 0])
        if not forward_only:
            self.loss = seq2seq.sequence_loss(logits=logits, targets=targets,
                                              weights=self.target_weights)
def __init__(self, vocab_size, embed_size, num_layers, hidden_size, eos, max_len,
             initial_embed=None):
    super(Seq2SeqAttn, self).__init__()
    self.vocab_size = vocab_size
    self.embed_size = embed_size
    self.num_layers = num_layers
    self.hidden_size = hidden_size  # TO DO
    self.EOS = eos
    self.PAD = 0

    # placeholders
    self.encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32,
                                         name='encoder_inputs')
    self.encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32,
                                                name='encoder_inputs_length')
    self.decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32,
                                          name='decoder_targets')
    self.rewards = tf.placeholder(shape=(None,), dtype=tf.float32, name='rewards')
    # self.decoder_targets_length = tf.placeholder(shape=(None,), dtype=tf.int32,
    #                                              name='decoder_targets_length')

    # LSTM cells; note that the encoder and decoder stacks wrap the same
    # underlying cell objects here
    cells = []
    for _ in range(num_layers):
        cells.append(tf.contrib.rnn.LSTMCell(hidden_size))
    self.encoder_cell = MultiRNNCell(cells)
    self.decoder_cell = MultiRNNCell(cells)

    # decoder train feeds
    with tf.name_scope('decoder_feeds'):
        sequence_size, batch_size = tf.unstack(tf.shape(self.decoder_targets))
        EOS_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.EOS
        PAD_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.PAD
        self.decoder_train_inputs = tf.concat([EOS_SLICE, self.decoder_targets],
                                              axis=0)
        # self.decoder_train_length = self.decoder_targets_length + 1  # one for EOS
        self.decoder_train_length = tf.cast(
            tf.ones(shape=(batch_size,)) * (max_len + 1), tf.int32)
        decoder_train_targets = tf.concat([self.decoder_targets, PAD_SLICE],
                                          axis=0)  # (seq_len + 1) * batch_size
        # decoder_train_targets_seq_len, _ = tf.unstack(tf.shape(decoder_train_targets))
        # decoder_train_targets_eos_mask = tf.one_hot(
        #     self.decoder_targets_length,
        #     decoder_train_targets_seq_len,
        #     on_value=self.EOS,
        #     off_value=self.PAD,
        #     dtype=tf.int32
        # )  # batch_size * (seq_len + 1)
        # decoder_train_targets_eos_mask = tf.transpose(decoder_train_targets_eos_mask, [1, 0])
        # decoder_train_targets = tf.add(decoder_train_targets, decoder_train_targets_eos_mask)
        self.decoder_train_targets = decoder_train_targets  # (seq_len + 1) * batch_size, with EOS at the end of each sentence

        # loss_weights
        loss_weights = tf.cast(tf.cast(self.decoder_train_targets, tf.bool), tf.float32)
        self.loss_weights = tf.transpose(loss_weights, perm=[1, 0])
        # self.loss_weights = tf.ones([batch_size, tf.reduce_max(self.decoder_train_length)],
        #                             dtype=tf.float32, name="loss_weights")

    # embedding layer
    with tf.variable_scope('embedding'):
        # if initial_embed:
        self.embedding = tf.Variable(initial_embed, name='matrix', dtype=tf.float32)
        # else:
        #     self.embedding = tf.Variable(
        #         tf.random_normal([vocab_size, embed_size],
        #                          -0.5 / embed_size, 0.5 / embed_size),
        #         name='matrix', dtype=tf.float32)
        self.encoder_inputs_embedded = tf.nn.embedding_lookup(self.embedding,
                                                              self.encoder_inputs)
        self.decoder_train_inputs_embedded = tf.nn.embedding_lookup(
            self.embedding, self.decoder_train_inputs)

    # encoder
    with tf.variable_scope('encoder'):
        self.encoder_outputs, self.encoder_state = tf.nn.dynamic_rnn(
            cell=self.encoder_cell,
            inputs=self.encoder_inputs_embedded,
            sequence_length=self.encoder_inputs_length,
            time_major=True,
            dtype=tf.float32)

    # decoder
    with tf.variable_scope('decoder') as scope:

        def output_fn(outputs):
            return tf.contrib.layers.linear(outputs, self.vocab_size, scope=scope)

        attention_states = tf.transpose(self.encoder_outputs,
                                        [1, 0, 2])  # batch_size * seq_len * hidden_size
        (attention_keys, attention_values, attention_score_fn,
         attention_construct_fn) = seq2seq.prepare_attention(
            attention_states=attention_states,
            attention_option="bahdanau",
            num_units=self.hidden_size,
        )
        decoder_fn_train = seq2seq.attention_decoder_fn_train(
            encoder_state=self.encoder_state,
            attention_keys=attention_keys,
            attention_values=attention_values,
            attention_score_fn=attention_score_fn,
            attention_construct_fn=attention_construct_fn,
            name='attention_decoder')
        decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
            output_fn=output_fn,
            encoder_state=self.encoder_state,
            attention_keys=attention_keys,
            attention_values=attention_values,
            attention_score_fn=attention_score_fn,
            attention_construct_fn=attention_construct_fn,
            embeddings=self.embedding,
            start_of_sequence_id=self.EOS,
            end_of_sequence_id=self.EOS,
            maximum_length=35,
            num_decoder_symbols=self.vocab_size,
        )
        (self.decoder_outputs_train, self.decoder_state_train,
         self.decoder_context_state_train) = seq2seq.dynamic_rnn_decoder(
            cell=self.decoder_cell,
            decoder_fn=decoder_fn_train,
            inputs=self.decoder_train_inputs_embedded,
            sequence_length=self.decoder_train_length,
            time_major=True,
            scope=scope,
        )
        self.decoder_logits_train = output_fn(self.decoder_outputs_train)
        self.decoder_prediction_train = tf.argmax(self.decoder_logits_train, axis=-1,
                                                  name='decoder_prediction_train')
        scope.reuse_variables()
        (self.decoder_logits_inference, self.decoder_state_inference,
         self.decoder_context_state_inference) = (
            seq2seq.dynamic_rnn_decoder(
                cell=self.decoder_cell,
                decoder_fn=decoder_fn_inference,
                time_major=True,
                scope=scope,
            ))
        self.decoder_prediction_inference = tf.argmax(
            self.decoder_logits_inference, axis=-1,
            name='decoder_prediction_inference')

    # optimizer
    with tf.name_scope('optimizer'):
        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        # self.policy_step = tf.Variable(0, name='policy_step', trainable=False)
        logits = tf.transpose(self.decoder_logits_train,
                              [1, 0, 2])  # batch_size * sequence_length * vocab_size
        targets = tf.transpose(self.decoder_train_targets, [1, 0])
        logits_inference = tf.transpose(self.decoder_logits_inference, [1, 0, 2])
        output_prob = tf.reduce_max(tf.nn.softmax(logits_inference),
                                    axis=2)  # batch_size * seq_len
        seq_log_prob = tf.reduce_sum(tf.log(output_prob), axis=1)
        self.policy_loss = -tf.reduce_sum(self.rewards * seq_log_prob)
        self.policy_op = tf.train.AdamOptimizer().minimize(self.policy_loss)
        self.loss = seq2seq.sequence_loss(logits=logits, targets=targets,
                                          weights=self.loss_weights)
        self.train_op = tf.train.AdamOptimizer().minimize(
            self.loss, global_step=self.global_step)
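# --- Policy-gradient usage sketch (not part of the original source) ---
# policy_op above is a REINFORCE-style update: seq_log_prob approximates the
# log probability of the greedy inference output, and each sequence's gradient
# is weighted by a scalar reward fed through the `rewards` placeholder.
# `model`, `sess`, `enc_batch`, `enc_lengths`, and `compute_reward` are
# illustrative names, not from the original code.
replies = sess.run(model.decoder_prediction_inference,
                   {model.encoder_inputs: enc_batch,
                    model.encoder_inputs_length: enc_lengths})
reward_values = compute_reward(replies)  # one scalar per batch element
sess.run(model.policy_op,
         {model.encoder_inputs: enc_batch,
          model.encoder_inputs_length: enc_lengths,
          model.rewards: reward_values})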
def _build_graph(self):
    # required only for training
    self.targets = tf.placeholder(shape=(None, None), dtype=tf.int32,
                                  name="decoder_inputs")
    self.targets_length = tf.placeholder(shape=(None,), dtype=tf.int32,
                                         name="decoder_inputs_length")
    self.global_step = tf.Variable(0, name="global_step", trainable=False)

    with tf.name_scope("DecoderTrainFeed"):
        sequence_size, batch_size = tf.unstack(tf.shape(self.targets))
        EOS_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.EOS
        PAD_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.PAD
        self.train_inputs = tf.concat([EOS_SLICE, self.targets], axis=0)
        self.train_length = self.targets_length + 1
        train_targets = tf.concat([self.targets, PAD_SLICE], axis=0)
        train_targets_seq_len, _ = tf.unstack(tf.shape(train_targets))
        train_targets_eos_mask = tf.one_hot(self.train_length - 1,
                                            train_targets_seq_len,
                                            on_value=self.EOS,
                                            off_value=self.PAD,
                                            dtype=tf.int32)
        train_targets_eos_mask = tf.transpose(train_targets_eos_mask, [1, 0])
        # hacky way of using one_hot to put the EOS symbol at the end of each
        # target sequence
        train_targets = tf.add(train_targets, train_targets_eos_mask)
        self.train_targets = train_targets
        self.loss_weights = tf.ones([batch_size, tf.reduce_max(self.train_length)],
                                    dtype=tf.float32, name="loss_weights")

    with tf.variable_scope("embedding") as scope:
        self.inputs_embedded = tf.nn.embedding_lookup(self.embedding_matrix,
                                                      self.train_inputs)

    with tf.variable_scope("Decoder") as scope:

        def logits_fn(outputs):
            return layers.linear(outputs, self.vocab_size, scope=scope)

        if not self.attention:
            train_fn = seq2seq.simple_decoder_fn_train(
                encoder_state=self.encoder_state)
            inference_fn = seq2seq.simple_decoder_fn_inference(
                output_fn=logits_fn,
                encoder_state=self.encoder_state,
                embeddings=self.embedding_matrix,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=tf.reduce_max(self.encoder_inputs_length) + 3,
                num_decoder_symbols=self.vocab_size)
        else:
            # attention_states: size [batch_size, max_time, num_units]
            attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2])
            (attention_keys, attention_values, attention_score_fn,
             attention_construct_fn) = seq2seq.prepare_attention(
                attention_states=attention_states,
                attention_option="bahdanau",
                num_units=self.decoder_hidden_units)
            train_fn = seq2seq.attention_decoder_fn_train(
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                name="decoder_attention")
            inference_fn = seq2seq.attention_decoder_fn_inference(
                output_fn=logits_fn,
                encoder_state=self.encoder_state,
                attention_keys=attention_keys,
                attention_values=attention_values,
                attention_score_fn=attention_score_fn,
                attention_construct_fn=attention_construct_fn,
                embeddings=self.embedding_matrix,
                start_of_sequence_id=self.EOS,
                end_of_sequence_id=self.EOS,
                maximum_length=tf.reduce_max(self.encoder_inputs_length) + 3,
                num_decoder_symbols=self.vocab_size)

        (self.train_outputs, self.train_state,
         self.train_context_state) = seq2seq.dynamic_rnn_decoder(
            cell=self.cell,
            decoder_fn=train_fn,
            inputs=self.inputs_embedded,
            sequence_length=self.train_length,
            time_major=True,
            scope=scope)
        self.train_logits = logits_fn(self.train_outputs)
        self.train_prediction = tf.argmax(self.train_logits, axis=-1,
                                          name="train_prediction")
        self.train_prediction_probabilities = tf.nn.softmax(
            self.train_logits, dim=-1, name="train_prediction_probabilities")

        scope.reuse_variables()

        (self.inference_logits, self.inference_state,
         self.inference_context_state) = seq2seq.dynamic_rnn_decoder(
            cell=self.cell,
            decoder_fn=inference_fn,
            time_major=True,
            scope=scope)
        self.inference_prediction = tf.argmax(self.inference_logits, axis=-1,
                                              name="inference_prediction")
        self.inference_prediction_probabilities = tf.nn.softmax(
            self.inference_logits, dim=-1,
            name="inference_prediction_probabilities")
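# --- Inference usage sketch (not part of the original source) ---
# A minimal session run against the graph built in _build_graph, assuming a
# `model` object that also defines encoder placeholders (encoder_inputs,
# encoder_inputs_length) elsewhere; feeds are time-major
# [max_time, batch_size] to match time_major=True in the dynamic_rnn_decoder
# calls. All names here are illustrative.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    predicted_ids = sess.run(model.inference_prediction, feed_dict={
        model.encoder_inputs: enc_batch,           # [max_time, batch_size] int32
        model.encoder_inputs_length: enc_lengths,  # [batch_size] int32
    })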