def call(self, inputs, state):
    """Advance the wrapped cell one step, then attend over its output.

    Args:
        inputs: Input for the current time step.
        state: Wrapper state from the previous step. Its `time`,
            `cell_state`, `attention`, `alignments`, `alignment_history`
            and `attention_history` fields are read.

    Returns:
        A `(cell_output, next_state)` pair, where `next_state` is a
        `MyAttentionWrapperState` whose `time` is advanced by one.
    """
    step = state.time

    # The previous attention is mixed into the input before the cell runs.
    rnn_output, rnn_state = self.cell(
        self.cell_input_fn(inputs, state.attention), state.cell_state)
    rnn_output = tf.identity(rnn_output, name="checked_cell_output")

    context, weights = _compute_attention(
        self.attention_mechanism, rnn_output, state.alignments,
        self.attention_layer)

    # Alignments are recorded only on request; attention is always recorded.
    if self.use_alignment_history:
        alignment_log = state.alignment_history.write(step, weights)
    else:
        alignment_log = ()
    attention_log = state.attention_history.write(step, context)

    return rnn_output, MyAttentionWrapperState(
        time=step + 1,
        cell_state=rnn_state,
        attention=context,
        attention_history=attention_log,
        alignments=weights,
        alignment_history=alignment_log)
def call(self, inputs, state):
    """Compute attention from the step input, then advance the wrapped cell.

    Each attention mechanism is queried with `inputs` (not with the cell
    output) and with `attention_state=None`. The wrapped cell is fed the
    attention stored in the incoming `state`; the attention computed here is
    stored in the returned state.

    Args:
        inputs: Input for the current time step.
        state: An `AttentionWrapperState_v2` from the previous step.

    Returns:
        `(attention, next_state)` when `self._output_attention` is truthy,
        otherwise `(cell_output, next_state)`. `next_state.last_choice` is
        copied unchanged from `state`.

    Raises:
        TypeError: If `state` is not an `AttentionWrapperState_v2`.
    """
    if not isinstance(state, AttentionWrapperState_v2):
        raise TypeError(
            "Expected state to be instance of AttentionWrapperState. "
            "Received type %s instead." % type(state))

    # Normalise to a per-mechanism sequence so the loop can index it.
    if self._is_multi:
        histories_in = state.alignment_history
    else:
        histories_in = [state.alignment_history]

    step_alignments = []
    step_attentions = []
    step_histories = []
    for idx, mechanism in enumerate(self._attention_mechanisms):
        layer = self._attention_layers[idx] if self._attention_layers else None
        # The third return value (next attention state) is not used here.
        context, weights, _ = _compute_attention(
            mechanism, inputs, attention_state=None, attention_layer=layer)
        if self._alignment_history:
            history = histories_in[idx].write(state.time, weights)
        else:
            history = ()
        step_alignments.append(weights)
        step_attentions.append(context)
        step_histories.append(history)
    attention = array_ops.concat(step_attentions, 1)

    # The cell sees the attention carried in the incoming state.
    cell_output, next_cell_state = self._cell(
        self._cell_input_fn(inputs, state.attention), state.cell_state)

    # Prefer the static batch size; fall back to the dynamic one.
    cell_batch_size = (
        cell_output.shape[0].value or array_ops.shape(cell_output)[0])
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    batch_checks = self._batch_size_checks(cell_batch_size, error_message)
    with ops.control_dependencies(batch_checks):
        cell_output = array_ops.identity(
            cell_output, name="checked_cell_output")

    next_state = AttentionWrapperState_v2(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=self._item_or_tuple(step_alignments),
        alignment_history=self._item_or_tuple(step_histories),
        last_choice=state.last_choice)

    if self._output_attention:
        return attention, next_state
    return cell_output, next_state
def __call__(self, inputs, state):
    """Run one decoder step: prenet, LSTM, attention, output projections.

    Args:
        inputs: Input for the current step; passed through `self._prenet`.
        state: A `TacotronDecoderCellState` from the previous step.

    Returns:
        `((cell_outputs, stop_tokens), next_state)` where `next_state` is a
        `TacotronDecoderCellState` with `time` advanced by one and
        `finished` copied unchanged from `state`.
    """
    # Prenet output and previous context vector together feed the LSTM
    # (input feeding).
    bottleneck = self._prenet(inputs)
    lstm_in = tf.concat([bottleneck, state.attention], axis=-1)
    lstm_out, new_cell_state = self._cell(lstm_in, state.cell_state)

    # The fresh LSTM output is the attention query, following the choice the
    # original author attributes to Luong et al. (2015),
    # https://arxiv.org/pdf/1508.04025.pdf. The alignments carried in the
    # state are handed to the mechanism as its previous state.
    context, step_alignments, cumulative_alignments = _compute_attention(
        self._attention_mechanism, lstm_out, state.alignments,
        attention_layer=None)

    # Both output heads read the LSTM output concatenated with the context.
    head_input = tf.concat([lstm_out, context], axis=-1)
    frames = self._frame_projection(head_input)
    stop = self._stop_projection(head_input)

    # Optionally zero the recorded alignments where `state.finished` is set.
    # Only the history entry is masked; `context` and the alignments stored
    # in the next state are left untouched.
    recorded = step_alignments
    if self._mask_finished:
        done = tf.cast(
            state.finished * tf.ones(tf.shape(step_alignments)), tf.bool)
        blank = tf.zeros(tf.shape(step_alignments))
        recorded = tf.where(done, blank, step_alignments)

    history = state.alignment_history.write(state.time, recorded)

    next_state = TacotronDecoderCellState(
        time=state.time + 1,
        cell_state=new_cell_state,
        attention=context,
        alignments=cumulative_alignments,
        alignment_history=history,
        finished=state.finished)
    return (frames, stop), next_state
def call(self, inputs, state):
    """Attend using the previous cell state, then run the wrapped cell.

    Note that this wrapper changes the order of the attention operations
    compared with the stock implementation: attention is computed before
    the cell is called, and the cell receives the attention computed in this
    same step.

    Args:
        inputs: Input for the current time step.
        state: A `seq2seq.AttentionWrapperState` from the previous step.

    Returns:
        `(attention, next_state)` when `self._output_attention` is truthy,
        otherwise `(cell_output, next_state)`.

    Raises:
        TypeError: If `state` is not a `seq2seq.AttentionWrapperState`.
    """
    if not isinstance(state, seq2seq.AttentionWrapperState):
        raise TypeError(
            "Expected state to be instance of AttentionWrapperState. "
            "Received type %s instead." % type(state))

    # Normalise to per-mechanism sequences so the loop can index them.
    if self._is_multi:
        alignments_in = state.alignments
        histories_in = state.alignment_history
    else:
        alignments_in = [state.alignments]
        histories_in = [state.alignment_history]

    # The query is the previous cell state; for an LSTM cell only `h` is
    # used. Adapted per https://github.com/bgshih/aster/issues/2
    if isinstance(self._cell, rnn.LSTMCell):
        query = state.cell_state.h
    else:
        query = state.cell_state

    step_alignments = []
    step_attentions = []
    step_histories = []
    step_attention_states = []
    for idx, mechanism in enumerate(self._attention_mechanisms):
        layer = self._attention_layers[idx] if self._attention_layers else None
        context, weights, mechanism_state = _compute_attention(
            mechanism, query, alignments_in[idx], layer)
        if self._alignment_history:
            history = histories_in[idx].write(state.time, weights)
        else:
            history = ()
        step_alignments.append(weights)
        step_histories.append(history)
        step_attentions.append(context)
        step_attention_states.append(mechanism_state)
    attention = array_ops.concat(step_attentions, 1)

    # The cell consumes the attention computed just above.
    cell_output, next_cell_state = self._cell(
        self._cell_input_fn(inputs, attention), state.cell_state)

    next_state = seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=self._item_or_tuple(step_alignments),
        alignment_history=self._item_or_tuple(step_histories),
        attention_state=self._item_or_tuple(step_attention_states))

    if self._output_attention:
        return attention, next_state
    return cell_output, next_state
def __call__(self, inputs, state):
    """Run one decoder step: prenet, RNN cell, attention, output projections.

    Args:
        inputs: Input for the current step. (Per the original author's note,
            this is the true output of the current step, used together with
            the previous decoder state to predict the next step.)
        state: Wrapper state from the previous step; its `attention`,
            `cell_state`, `alignments`, `alignment_history`, `time` and
            `attention_state` fields are read.

    Returns:
        `((cell_outputs, stop_tokens), next_state)`. `stop_tokens` is passed
        through a sigmoid when `self._training == False`.
    """
    # Dropout is only active while training.
    drop_rate = 0.5 if self._training else 0.0

    # Prenet: two Dense + Dropout stages applied to the step input.
    with tf.variable_scope('decoder_prenet'):
        for layer_no, units in enumerate([256, 128], start=1):
            hidden = tf.keras.layers.Dense(
                units=units,
                activation=tf.nn.relu,
                name='dense_%d' % layer_no)(inputs)
            inputs = tf.keras.layers.Dropout(
                rate=drop_rate,
                name='dropout_%d' % layer_no)(hidden, training=self._training)

    # Append the previous attention, project with a dense layer, then feed
    # the decoder RNN.
    rnn_input = tf.concat([inputs, state.attention], axis=-1)
    projected_input = tf.keras.layers.Dense(256)(rnn_input)
    rnn_output, next_cell_state = self._cell(projected_input, state.cell_state)

    # Attention for this step. Per the original note, `state.alignments`
    # holds the cumulative attention from the previous step.
    context_vector, alignments, cumulated_alignments = (
        attention_wrapper._compute_attention(
            self._attention_mechanism, rnn_output, state.alignments, None))

    # Record this step's alignments, as the stock AttentionWrapper does.
    alignment_history = state.alignment_history.write(state.time, alignments)

    next_state = tf.contrib.seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=context_vector,
        alignments=cumulated_alignments,
        alignment_history=alignment_history,
        attention_state=state.attention_state)

    # Two projections over [rnn_output, context]: one predicts the output
    # (per the original note, mel features for the next `outputs_per_step`
    # frames), the other the <stop_token>.
    projections_input = tf.concat([rnn_output, context_vector], axis=-1)
    cell_outputs = self._frame_projection(projections_input)
    stop_tokens = self._stop_projection(projections_input)
    # Kept as an equality test so that a `None` flag does not take this path.
    if self._training == False:  # noqa: E712
        stop_tokens = tf.nn.sigmoid(stop_tokens)

    return (cell_outputs, stop_tokens), next_state
def __call__(self, inputs, state):
    """Run one decoder step, attending before the LSTM stack is called.

    Args:
        inputs: Input for the current step (per the original comment, the
            previously predicted frame); passed through `self._prenet`.
        state: A `TacotronDecoderCellState` from the previous step. Its
            `cell_state` must unpack into exactly two per-layer states.

    Returns:
        `((cell_outputs, stop_tokens), next_state)` where `next_state` is a
        `TacotronDecoderCellState` with `time` advanced by one.
    """
    bottleneck = self._prenet(inputs)

    # The query is the hidden state `h` of the last of the two layer states.
    # The original author cites Luong et al. (2015) for using the top layer:
    # https://arxiv.org/pdf/1508.04025.pdf
    _, top_layer_state = state.cell_state
    query = top_layer_state.h

    # The alignments carried in the state are passed to the mechanism as its
    # previous state; the third return value is not used.
    context, weights, _ = _compute_attention(
        self._attention_mechanism, query, state.alignments,
        attention_layer=None)

    # This step's context vector (not the previous one) feeds the LSTM stack.
    lstm_in = tf.concat([bottleneck, context], axis=-1)
    lstm_out, new_cell_state = self._cell(lstm_in, state.cell_state)

    # Both output heads read the LSTM output concatenated with the context.
    head_input = tf.concat([lstm_out, context], axis=-1)
    frames = self._frame_projection(head_input)
    stop = self._stop_projection(head_input)

    history = state.alignment_history.write(state.time, weights)

    next_state = TacotronDecoderCellState(
        time=state.time + 1,
        cell_state=new_cell_state,
        attention=context,
        alignments=weights,
        alignment_history=history)
    return (frames, stop), next_state
def call(self, inputs, state):
    """Perform a step of attention-wrapped RNN, optionally with coverage.

    The wrapped cell is run on `cell_input_fn(inputs, state.attention)` and
    every attention mechanism is then queried with the cell output. When
    `self.coverage` is set, each mechanism receives the sum of all alignments
    recorded so far (zeros while the history is empty) instead of the
    previous step's alignments.

    Coverage mode reads the alignment history, so it presumably requires the
    wrapper to have been built with alignment history enabled.

    Args:
      inputs: (Possibly nested tuple of) Tensor, the input at this time step.
      state: An instance of `AttentionWrapperState` containing tensors from
        the previous time step. It is not modified.

    Returns:
      A tuple `(attention_or_cell_output, next_state)`, where:

      - `attention_or_cell_output` is the attention if
        `self._output_attention` is truthy, else the cell output.
      - `next_state` is an `AttentionWrapperState` for the next step; its
        `attention_state` is copied unchanged from `state`.

    Raises:
      TypeError: If `state` is not an instance of `AttentionWrapperState`.
    """
    if not isinstance(state, tf.contrib.seq2seq.AttentionWrapperState):
        raise TypeError(
            "Expected state to be instance of AttentionWrapperState. "
            "Received type %s instead." % type(state))

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value, and run the wrapped cell.
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

    # Prefer the static batch size; fall back to the dynamic one.
    cell_batch_size = (
        cell_output.shape[0].value or tf.shape(cell_output)[0])
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    with tf.control_dependencies(
            self._batch_size_checks(cell_batch_size, error_message)):
        cell_output = tf.identity(cell_output, name="checked_cell_output")

    if self._is_multi:
        previous_alignments = state.alignments
        previous_alignment_history = state.alignment_history
    else:
        previous_alignments = [state.alignments]
        previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_histories = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
        # Use a local instead of writing back into `previous_alignments`:
        # in multi-mechanism mode that container is `state.alignments`
        # itself, which may be an immutable tuple and must not be mutated.
        mechanism_state = previous_alignments[i]
        if self.coverage:
            # The history stacks to (decoder time, batch, attention length);
            # summing over time gives the coverage vector. While the history
            # is still empty, fall back to zeros of the alignment shape.
            history = previous_alignment_history[i]
            mechanism_state = tf.cond(
                history.size() > 0,
                lambda hist=history: tf.reduce_sum(hist.stack(), axis=0),
                lambda prev=mechanism_state: tf.zeros_like(prev))

        attention, alignments, _ = _compute_attention(
            attention_mechanism, cell_output, mechanism_state,
            self._attention_layers[i] if self._attention_layers else None)
        alignment_history = previous_alignment_history[i].write(
            state.time, alignments) if self._alignment_history else ()

        all_alignments.append(alignments)
        all_histories.append(alignment_history)
        all_attentions.append(attention)

    attention = tf.concat(all_attentions, 1)
    next_state = seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(all_histories),
        attention_state=state.attention_state)

    if self._output_attention:
        return attention, next_state
    else:
        return cell_output, next_state
def call(self, inputs, state):
    """Run one decoding step and emit scores over up to three vocabularies.

    The step runs `self._std_cell`, scores the fact candidates against a
    projection of the cell output, applies the attention mechanisms to the
    cell output, and computes a softmax over the common-word projection.
    When entity or copy prediction is enabled, the common, copy and entity
    distributions are each scaled by a selector weight, concatenated along
    the last axis, and returned as `log(p + 1e-20)`.

    NOTE(review): the block nesting here was reconstructed from context.
    Please confirm in particular that everything from `selector_mask` down
    to the final `safe_log` call belongs inside the
    `if self._entity_predict_mode or self._copy_predict_mode:` branch, and
    that only the cell calls sit inside the `tf.variable_scope` blocks.

    Args:
      inputs: Tensor, the input at this time step. It is concatenated with
        the cell output and attention on the last axis.
      state: An instance of `HGFUAttentionWrapperState` containing tensors
        from the previous time step.

    Returns:
      A tuple `(cell_output, next_state)`, where `next_state` is an
      `HGFUAttentionWrapperState` with `time` advanced by one.

    Raises:
      TypeError: If `state` is not an instance of `HGFUAttentionWrapperState`.
    """
    if not isinstance(state, HGFUAttentionWrapperState):
        raise TypeError("Expected state to be instance of AttentionWrapperState. "
                        "Received type %s instead." % type(state))

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value.
    cell_state = state.cell_state
    model_selector_openness = state.model_selector_openness
    # `copy_alignments` and `fact_memory_alignments` are carried through to
    # the next state unchanged.
    copy_alignments = state.copy_alignments
    fact_alignments = state.fact_alignments
    fact_memory_alignments = state.fact_memory_alignments
    batch_size = tf.shape(self._lengths_for_fanct_candidates)[0]
    # Default when no new id is predicted below.
    next_last_id = state.last_id
    decoding_mask = state.decoding_mask

    # [batch, dim]
    std_cell_inputs = tf.concat([self._cell_input_fn(inputs, state.attention)], -1)
    with tf.variable_scope('std_gru'):
        outputs_std, cell_state = self._std_cell(std_cell_inputs, cell_state)

    cell_output = outputs_std
    next_cell_state = cell_state

    # Fact-candidate scoring: project [cell output, inputs] to a query, tile
    # it across the candidate axis and take a dot product with each candidate.
    fact_embedding = self._fact_candidates
    maximium_candidate_num = tf.shape(fact_embedding)[1]
    # [batch, embed]
    dynamic_inputs_1 = tf.concat([outputs_std, inputs], -1)
    entity_update_scores_p1 = tf.layers.dense(dynamic_inputs_1, units=self._sim_vec_dim, activation=tf.nn.tanh,
                                              name='entity_query_projection')
    entity_update_scores_p2 =self._fact_candidates
    entity_update_scores_p1 = tf.expand_dims(entity_update_scores_p1, 1)
    entity_update_scores_p1 = tf.tile(entity_update_scores_p1, [1, maximium_candidate_num, 1])
    entity_update_scores = tf.reduce_sum(entity_update_scores_p1 * entity_update_scores_p2, -1)
    # Positions beyond each row's candidate count get a -1e10 additive bias.
    entity_update_mask = tf.sequence_mask(self._lengths_for_fanct_candidates, dtype=tf.float32)
    entity_update_mask = (1.0 - entity_update_mask) * -1e10
    entity_update_scores += entity_update_mask
    entity_logits = entity_update_scores

    # Prefer the static batch size; fall back to the dynamic one.
    cell_batch_size = (
        cell_output.shape[0].value or array_ops.shape(cell_output)[0])
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    with ops.control_dependencies(
            self._batch_size_checks(cell_batch_size, error_message)):
        cell_output = array_ops.identity(
            cell_output, name="checked_cell_output")

    # Normalise to per-mechanism sequences so the loop can index them.
    if self._is_multi:
        previous_attention_state = state.attention_state
        previous_alignment_history = state.alignment_history
    else:
        previous_attention_state = [state.attention_state]
        previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_attention_states = []
    maybe_all_histories = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
        attention, alignments, next_attention_state = attention_wrapper._compute_attention(
            attention_mechanism, cell_output, previous_attention_state[i],
            self._attention_layers[i] if self._attention_layers else None)
        # History is only written when alignment history is enabled.
        alignment_history = previous_alignment_history[i].write(
            state.time, alignments) if self._alignment_history else ()

        all_attention_states.append(next_attention_state)
        all_alignments.append(alignments)
        all_attentions.append(attention)
        maybe_all_histories.append(alignment_history)

    attention = array_ops.concat(all_attentions, 1)

    # Keep the checked cell output; `cell_output` is overwritten below.
    cell_output_org = cell_output

    # Common-word distribution over [cell output, attention, inputs], with an
    # optional ELU projection in between.
    common_word_inputs = tf.concat([cell_output, attention, inputs], -1)
    if self.mid_projection_dim > -1:
        common_word_inputs = tf.layers.dense(common_word_inputs, self.mid_projection_dim, tf.nn.elu)
    common_word_logits = self._common_word_projection(common_word_inputs)
    common_probs = tf.nn.softmax(common_word_logits, -1)

    if self._entity_predict_mode or self._copy_predict_mode:
        # One flag per mode, in the order [common, copy, entity]; a 0 entry
        # adds -1e10 to that mode's selector logit below.
        selector_mask = [1.0, 1.0, 1.0]

        if self._copy_predict_mode:
            # encoder_memory [batch, seq_len, embedding]
            batch_num = tf.shape(self._encoder_memory)[0]
            max_encoder_len = tf.shape(self._encoder_memory)[1]
            # => [batch, seq_len, embedding]
            copy_input_p1 = tf.concat([cell_output, attention, inputs], -1)
            copy_layer1_p1 = tf.layers.dense(copy_input_p1, units=self._sim_vec_dim,
                                             activation=tf.nn.tanh, name='copy_query')
            copy_layer1_p2 = self._transformed_encoder_memory
            copy_layer1_p1 = tf.tile(tf.expand_dims(copy_layer1_p1, 1), [1, max_encoder_len, 1])
            copy_logits = tf.reduce_sum(copy_layer1_p1 * copy_layer1_p2, -1)
            # Positions beyond each row's encoder length get a -1e10 bias.
            copy_mask = tf.sequence_mask(self._encoder_memory_len, dtype=tf.float32)
            copy_mask = (1.0 - copy_mask) * -1e10
            copy_logits += copy_mask
            copy_probs = tf.nn.softmax(copy_logits, -1)
            # Pad with zero probabilities up to the fixed copy-vocabulary size.
            # NOTE(review): the pad width is derived from the longest encoder
            # length, which assumes the memory is padded to exactly that
            # length — confirm.
            padding_num = self._copy_vocab_size - tf.reduce_max(self._encoder_memory_len)
            padding_probs = tf.zeros([batch_num, padding_num])
            copy_probs = tf.concat([copy_probs, padding_probs], -1)
        else:
            # Copy mode off: all-zero copy probabilities, selector masked out.
            copy_logits = tf.ones([batch_size, self._copy_vocab_size], tf.float32) * -1e10
            copy_probs = tf.zeros([batch_size, self._copy_vocab_size], tf.float32)
            selector_mask[1] = 0.0

        if self._entity_predict_mode:
            entity_logits = entity_logits
            # Original note: "half and half" — presumably refers to the
            # 0.5 / 0.5 mix used when the balance gate is off.
            if self._kg_initial_goals is not None:
                if self.balance_gate is True:
                    # A sigmoid gate mixes the step's entity softmax with the
                    # initial goals.
                    balance_selector_input = tf.concat([cell_output_org, attention, inputs], -1)
                    balance_selector = tf.layers.dense(balance_selector_input, 1, activation=tf.nn.sigmoid,
                                                       name='entity_balance_selector')
                    entity_probs = tf.nn.softmax(entity_logits, -1) * balance_selector + self._kg_initial_goals * (
                        1.0 - balance_selector)
                    if self.k_openness_history:
                        fact_alignments = fact_alignments.write(state.time, entity_probs)
                else:
                    entity_probs = tf.nn.softmax(entity_logits, -1) * 0.5 + self._kg_initial_goals * 0.5
            else:
                entity_probs = tf.nn.softmax(entity_logits, -1)
            # Cap each entity probability by the decoding mask.
            entity_probs = tf.minimum(entity_probs, decoding_mask)
        else:
            # Entity mode off: all-zero entity probabilities, selector masked.
            selector_mask[2] = 0
            entity_probs = tf.zeros([batch_size, self._entity_vocab_size], tf.float32)

        # Mode selector: one shared hidden layer, then one scalar logit per
        # mode; disabled modes receive a -1e10 offset.
        mode_selector_input = tf.concat([cell_output_org, attention, inputs], -1)
        layer1 = tf.layers.dense(mode_selector_input, self._sim_vec_dim, use_bias=True, activation=tf.nn.relu,
                                 name='selector_l1')
        common_selector = tf.layers.dense(layer1, 1, use_bias=False, name='common_selector') + ((1.0 - selector_mask[0]) * -1e10)
        copy_selector = tf.layers.dense(layer1, 1, use_bias=False, name='copy_selector')+ ((1.0 - selector_mask[1]) * -1e10)
        entity_selector = tf.layers.dense(layer1, 1, use_bias=False, name='entity_selector') + ((1.0 - selector_mask[2]) * -1e10)

        # Softmax written out by hand: exponentiate, then divide by the sum.
        common_selector = tf.exp(common_selector)
        copy_selector = tf.exp(copy_selector)
        entity_selector = tf.exp(entity_selector)
        # The unnormalised exponentials are what gets recorded.
        model_selector_openness = model_selector_openness.write(
            state.time, tf.concat([common_selector, copy_selector, entity_selector], -1))
        exp_sum = common_selector + copy_selector + entity_selector
        common_selector = common_selector / exp_sum
        copy_selector = copy_selector / exp_sum
        entity_selector = entity_selector / exp_sum

        if self._binary_decoding:
            # Whichever of common / entity is larger receives both shares
            # (ties go to common); the other gets zero. Copy is unaffected.
            new_common_selector = tf.where(tf.greater_equal(common_selector, entity_selector),
                                           common_selector+entity_selector, tf.zeros_like(common_selector))
            new_entity_selector = tf.where(tf.greater(entity_selector, common_selector),
                                           common_selector+entity_selector, tf.zeros_like(entity_selector))
            entity_probs = entity_probs * new_entity_selector
            common_probs = common_probs * new_common_selector
            copy_probs = copy_probs * copy_selector
        else:
            entity_probs = entity_probs * entity_selector
            common_probs = common_probs * common_selector
            copy_probs = copy_probs * copy_selector

        # Output layout on the last axis: [common | copy | entity].
        cell_output = tf.concat([common_probs, copy_probs, entity_probs], -1)

        if self._cue_fact_mask:
            # If the arg-max falls in the entity segment, lower that fact's
            # mask entry to 1e-20 so later steps cap its probability.
            max_id = tf.argmax(cell_output, -1)
            # [batch]
            is_entity = tf.greater_equal(max_id, self._common_vocab_size + self._copy_vocab_size)
            # [batch]
            abs_id = tf.maximum(max_id - self._common_vocab_size - self._copy_vocab_size, 0)
            # [batch, fact_num]
            mask_used_fact = tf.one_hot(abs_id, maximium_candidate_num, on_value=1e-20, off_value=1.0)
            mask_used_fact = tf.where(is_entity, mask_used_fact, tf.ones_like(mask_used_fact))
            decoding_mask = tf.minimum(mask_used_fact, decoding_mask)

        # Record the arg-max id of this step for the next state.
        next_id = tf.to_int32(tf.argmax(cell_output, -1))
        next_last_id = tf.to_int32(next_id)

        def safe_log(inX):
            # The 1e-20 offset keeps the log finite for zero probabilities.
            return tf.log(inX + 1e-20)

        cell_output = safe_log((cell_output))

    next_state = HGFUAttentionWrapperState(
        last_id=next_last_id,
        model_selector_openness=model_selector_openness,
        copy_alignments=copy_alignments,
        fact_alignments=fact_alignments,
        fact_memory_alignments=fact_memory_alignments,
        time=state.time + 1,
        cell_state=next_cell_state,
        decoding_mask=decoding_mask,
        attention=attention,
        attention_state=self._item_or_tuple(all_attention_states),
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(maybe_all_histories))

    return cell_output, next_state
def call(self, inputs, state):
    """Perform one step of the attention-wrapped RNN with internal memory.

    Before the wrapped cell runs, the internal memory is read using the
    inputs, the last layer of the previous cell state and the previous
    attention. The cell input is the concatenation of the inputs, previous
    attention, `self._emo_cat_embs` and the memory read-out. After the cell
    runs, the internal memory is written using the last layer of the new
    cell state. Attention is then computed from the cell output.

    Args:
      inputs: Tensor, the input at this time step.
      state: An instance of `ECMWrapperState` containing tensors from the
        previous time step.

    Returns:
      A tuple `(attention_or_cell_output, next_state)`, where:

      - `attention_or_cell_output` is the attention if
        `self._output_attention` is truthy, else the cell output.
      - `next_state` is an `ECMWrapperState` for the next step, carrying
        the updated internal memory.

    Raises:
      TypeError: If `state` is not an instance of `ECMWrapperState`.
    """
    if not isinstance(state, ECMWrapperState):
        raise TypeError(
            "Expected state to be instance of ECMWrapperState. "
            "Received type %s instead." % type(state))

    def _last_layer_as_tensor(stacked_state, notice):
        # Take the last layer's state; for an LSTM state, join c and h.
        last_layer = stacked_state[-1]
        if isinstance(last_layer, LSTMStateTuple):
            print(notice)
            last_layer = tf.concat([last_layer.c, last_layer.h], axis=-1)
        return last_layer

    # Read from the internal memory using the previous step's information.
    read_state = _last_layer_as_tensor(
        state.cell_state, 'read gate concat LSTMState C and H')
    read_gate_inputs = tf.concat(
        [inputs, read_state, state.attention], axis=-1)
    memory_read = self._read_internal_memory(
        state.internal_memory, read_gate_inputs)

    # Input to the RNN cell for the current step.
    cell_inputs = tf.concat(
        [inputs, state.attention, self._emo_cat_embs, memory_read], axis=-1)
    cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

    # Write to the internal memory using the freshly produced cell state.
    write_state = _last_layer_as_tensor(
        next_cell_state, 'write gate concat LSTMState C and H')
    updated_memory = self._write_internal_memory(
        state.internal_memory, write_state)

    # Prefer the static batch size; fall back to the dynamic one.
    cell_batch_size = (
        cell_output.shape[0].value or array_ops.shape(cell_output)[0])
    error_message = (
        "When applying ECMWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    batch_checks = self._batch_size_checks(cell_batch_size, error_message)
    with ops.control_dependencies(batch_checks):
        cell_output = array_ops.identity(
            cell_output, name="checked_cell_output")

    # Normalise to per-mechanism sequences so the loop can index them.
    if self._is_multi:
        alignments_in = state.alignments
        histories_in = state.alignment_history
    else:
        alignments_in = [state.alignments]
        histories_in = [state.alignment_history]

    step_alignments = []
    step_attentions = []
    step_histories = []
    for idx, mechanism in enumerate(self._attention_mechanisms):
        layer = self._attention_layers[idx] if self._attention_layers else None
        context, weights = _compute_attention(
            mechanism, cell_output, alignments_in[idx], layer)
        if self._alignment_history:
            history = histories_in[idx].write(state.time, weights)
        else:
            history = ()
        step_alignments.append(weights)
        step_histories.append(history)
        step_attentions.append(context)
    attention = array_ops.concat(step_attentions, 1)

    # New state handed to the next time step.
    next_state = ECMWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        internal_memory=updated_memory,
        alignments=self._item_or_tuple(step_alignments),
        alignment_history=self._item_or_tuple(step_histories))

    if self._output_attention:
        return attention, next_state
    return cell_output, next_state
def call(self, inputs, state):
    """Perform one step of the attention-wrapped RNN with two fused cells.

    A standard cell and a cue cell are both run; a sigmoid gate `k`,
    computed from tanh projections of their outputs, blends the two outputs
    and the two states as `k * std + (1 - k) * cue`. Attention is then
    computed from the blended output.

    Args:
      inputs: Tensor, the input at this time step.
      state: An instance of `HGFUAttentionWrapperState` containing tensors
        from the previous time step.

    Returns:
      A tuple `(attention_or_cell_output, next_state)`, where:

      - `attention_or_cell_output` is the attention if
        `self._output_attention` is truthy, else the blended cell output.
      - `next_state` is an `HGFUAttentionWrapperState` for the next step.

    Raises:
      TypeError: If `state` is not an instance of `HGFUAttentionWrapperState`.
    """
    if not isinstance(state, HGFUAttentionWrapperState):
        raise TypeError(
            "Expected state to be instance of AttentionWrapperState. "
            "Received type %s instead." % type(state))

    step = state.time
    gate_log = state.k_openness

    # Inputs for both cells are built from the previous attention.
    std_in = self._cell_input_fn(inputs, state.attention)
    cue_in = tf.concat([state.attention, self._cue_inputs], axis=-1)

    with tf.variable_scope('std_gru'):
        std_out, std_state = self._std_cell(std_in, state.cell_state)

    # NOTE: the cue cell starts from the state the standard cell has just
    # produced, not from `state.cell_state`.
    with tf.variable_scope('cue_gru'):
        cue_out, cue_state = self._cue_cell(cue_in, std_state)

    # Fusion gate: tanh projections of both outputs feed a sigmoid gate.
    std_proj = tf.layers.dense(
        std_out,
        units=self._std_cell.state_size,
        activation=tf.nn.tanh,
        use_bias=False,
        name='FusionGate_HY')
    cue_proj = tf.layers.dense(
        cue_out,
        units=self._cue_cell.state_size,
        activation=tf.nn.tanh,
        use_bias=False,
        name='FusionGate_HW')
    k = tf.layers.dense(
        tf.concat([std_proj, cue_proj], -1),
        units=self._cue_cell.state_size,
        activation=tf.nn.sigmoid,
        use_bias=False,
        name='FusionGate_k')

    # Optionally record the mean gate value for this step.
    if self.k_openness_history:
        gate_log = gate_log.write(step, tf.reduce_mean(k, axis=-1))

    def _blend(std_part, cue_part):
        # Convex combination of the two branches under gate `k`.
        return k * std_part + (1.0 - k) * cue_part

    cell_output = _blend(std_out, cue_out)
    next_cell_state = _blend(std_state, cue_state)

    # Prefer the static batch size; fall back to the dynamic one.
    cell_batch_size = (
        cell_output.shape[0].value or array_ops.shape(cell_output)[0])
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    batch_checks = self._batch_size_checks(cell_batch_size, error_message)
    with ops.control_dependencies(batch_checks):
        cell_output = array_ops.identity(
            cell_output, name="checked_cell_output")

    # Normalise to per-mechanism sequences so the loop can index them.
    if self._is_multi:
        attention_states_in = state.attention_state
        histories_in = state.alignment_history
    else:
        attention_states_in = [state.attention_state]
        histories_in = [state.alignment_history]

    step_alignments = []
    step_attentions = []
    step_attention_states = []
    step_histories = []
    for idx, mechanism in enumerate(self._attention_mechanisms):
        layer = self._attention_layers[idx] if self._attention_layers else None
        context, weights, mechanism_state = attention_wrapper._compute_attention(
            mechanism, cell_output, attention_states_in[idx], layer)
        if self._alignment_history:
            history = histories_in[idx].write(step, weights)
        else:
            history = ()
        step_attention_states.append(mechanism_state)
        step_alignments.append(weights)
        step_attentions.append(context)
        step_histories.append(history)
    attention = array_ops.concat(step_attentions, 1)

    next_state = HGFUAttentionWrapperState(
        k_openness=gate_log,
        time=step + 1,
        cell_state=next_cell_state,
        attention=attention,
        attention_state=self._item_or_tuple(step_attention_states),
        alignments=self._item_or_tuple(step_alignments),
        alignment_history=self._item_or_tuple(step_histories))

    if self._output_attention:
        return attention, next_state
    return cell_output, next_state
def call(self, inputs, state):
    """Advance the wrapped cell one step along a main and a bias branch.

    The same wrapped cell is applied twice: once to `inputs` mixed with the
    previous `attention` (main branch) and once to `inputs` mixed with the
    previous `attention_bias` (bias branch).  Each branch then runs its own
    list of attention mechanisms, and the results of both branches are
    stored in the returned state.

    Args:
      inputs: (Possibly nested tuple of) Tensor, the input at this time step.
      state: An instance of `AttentionWrapperState` holding the tensors
        produced at the previous time step.

    Returns:
      A tuple `(output, next_state)`, where:

      - `output` is the bias-branch attention when `self._output_attention`
        is truthy, otherwise the main and bias cell outputs concatenated
        along the last axis.
      - `next_state` is the `AttentionWrapperState` computed at this step.

    Raises:
      TypeError: If `state` is not an instance of `AttentionWrapperState`.
    """
    state_is_valid = isinstance(state, AttentionWrapperState)
    if not state_is_valid:
        raise TypeError("Expected state to be instance of AttentionWrapperState. "
                        "Received type %s instead." % type(state))

    # Main branch: feed the cell with inputs mixed with the last attention.
    cell_output, next_cell_state = self._cell(
        self._cell_input_fn(inputs, state.attention), state.cell_state)

    # Bias branch: same cell, but driven by the last bias attention/state.
    bias_output, next_bias_state = self._cell(
        self._cell_input_fn(inputs, state.attention_bias), state.bias_state)

    # Prefer the static batch size; fall back to the dynamic one.
    static_batch_size = cell_output.shape[0].value
    cell_batch_size = static_batch_size or array_ops.shape(cell_output)[0]
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    batch_checks = self._batch_size_checks(cell_batch_size, error_message)
    with ops.control_dependencies(batch_checks):
        cell_output = array_ops.identity(
            cell_output, name="checked_cell_output")
        bias_output = array_ops.identity(
            bias_output, name="checked_bias_output")

    # ---- Main-branch attention -------------------------------------------
    main_prev_states = state.attention_state
    main_prev_histories = state.alignment_history
    if not self._is_multi:
        # Single mechanism: wrap so the loop below can index uniformly.
        main_prev_states = [main_prev_states]
        main_prev_histories = [main_prev_histories]

    all_alignments = []
    all_attentions = []
    all_attention_states = []
    maybe_all_histories = []

    # Score every head in one call, then swap the first two axes so that
    # indexing with `i` selects the slice handed to mechanism `i`.
    head_scores = self._my_attention_mechanism(cell_output)
    head_scores = tf.transpose(head_scores, [1, 0, 2])
    for i, mechanism in enumerate(self._attention_mechanisms):
        step_attention, step_alignments, step_state = _compute_attention(
            mechanism,
            head_scores[i],
            main_prev_states[i],
            self._attention_layers[i] if self._attention_layers else None)
        if self._alignment_history:
            step_history = main_prev_histories[i].write(
                state.time, step_alignments)
        else:
            step_history = ()
        all_attention_states.append(step_state)
        all_alignments.append(step_alignments)
        all_attentions.append(step_attention)
        maybe_all_histories.append(step_history)

    # ---- Bias-branch attention -------------------------------------------
    all_alignments_bias = []
    all_attentions_bias = []
    all_attention_states_bias = []
    maybe_all_histories_bias = []

    if self._is_multi_bias:
        bias_prev_states = state.attention_state_bias
        bias_prev_histories = state.alignment_history_bias
        # Multi-head bias: per-head scores, laid out like the main branch.
        bias_queries = self._my_attention_mechanism_bias(bias_output)
        bias_queries = tf.transpose(bias_queries, [1, 0, 2])
    else:
        bias_prev_states = [state.attention_state_bias]
        bias_prev_histories = [state.alignment_history_bias]
        # Single bias mechanism: it is queried with the cell output itself.
        bias_queries = bias_output

    # NOTE(review): this loop indexes the same `self._attention_layers` as
    # the main branch -- confirm that sharing the layers is intended.
    for i, mechanism in enumerate(self._attention_mechanisms_bias):
        step_attention, step_alignments, step_state = _compute_attention(
            mechanism,
            bias_queries[i] if self._is_multi_bias else bias_queries,
            bias_prev_states[i],
            self._attention_layers[i] if self._attention_layers else None)
        if self._alignment_history:
            step_history = bias_prev_histories[i].write(
                state.time, step_alignments)
        else:
            step_history = ()
        all_attention_states_bias.append(step_state)
        all_alignments_bias.append(step_alignments)
        all_attentions_bias.append(step_attention)
        maybe_all_histories_bias.append(step_history)

    # ---- Assemble the next state -----------------------------------------
    attention = array_ops.concat(all_attentions, 1)
    attention_bias = array_ops.concat(all_attentions_bias, 1)
    next_state = AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        bias_state=next_bias_state,
        attention=attention,
        attention_bias=attention_bias,
        attention_state=self._item_or_tuple(all_attention_states, False),
        attention_state_bias=self._item_or_tuple(
            all_attention_states_bias, True),
        alignments=self._item_or_tuple(all_alignments, False),
        alignments_bias=self._item_or_tuple(all_alignments_bias, True),
        alignment_history=self._item_or_tuple(maybe_all_histories, False),
        alignment_history_bias=self._item_or_tuple(
            maybe_all_histories_bias, True))

    if self._output_attention:
        return attention_bias, next_state
    return tf.concat([cell_output, bias_output], -1), next_state
def call(self, inputs, state):
    """Perform a step of the coverage-aware attention-wrapped RNN.

    - Step 1: Mix the `inputs` and previous step's `attention` output via
      `cell_input_fn`.
    - Step 2: Call the wrapped `cell` with this input and its previous state.
    - Step 3: For every attention mechanism, compute a fertility term
      `N * sigmoid(fertility_layer(values))`, divide the previous coverage
      by it, pass the result through `coverage_layer`, and hand it to
      `_compute_attention` together with the cell output.
    - Step 4: Add the new alignments to the previous coverage to obtain the
      coverage stored in the next state.

    Args:
      inputs: (Possibly nested tuple of) Tensor, the input at this time step.
      state: An instance of `CoverageAttentionWrapperState` containing
        tensors from the previous time step.

    Returns:
      A tuple `(attention_or_cell_output, next_state)`, where:

      - `attention_or_cell_output` depending on `output_attention`.
      - `next_state` is an instance of `CoverageAttentionWrapperState`
         containing the state calculated at this time step.

    Raises:
      TypeError: If `state` is not an instance of
        `CoverageAttentionWrapperState`.
    """
    if not isinstance(state, CoverageAttentionWrapperState):
        # Name the class that is actually checked, so the message is not
        # self-contradictory when a plain AttentionWrapperState is passed.
        raise TypeError(
            "Expected state to be instance of CoverageAttentionWrapperState. "
            "Received type %s instead." % type(state))

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value.
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

    # Use the static batch size when known, else the dynamic one.
    cell_batch_size = (cell_output.shape[0].value or
                       tf.shape(cell_output)[0])
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    with tf.control_dependencies(
            self._batch_size_checks(cell_batch_size, error_message)):
        cell_output = tf.identity(cell_output, name="checked_cell_output")

    if self._is_multi:
        previous_alignment_history = state.alignment_history
        previous_coverage = state.coverages
    else:
        # Single mechanism: wrap so the loop below can index uniformly.
        previous_alignment_history = [state.alignment_history]
        previous_coverage = [state.coverages]

    all_alignments = []
    all_attentions = []
    all_histories = []
    all_coverages = []

    # Trainable scale of the fertility term, initialised to 2.
    n = tf.get_variable("N", [1], dtype=tf.float32,
                        initializer=tf.constant_initializer(2))
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
        # values on attention_mechanism is hidden state from encoder
        # batch * atten_len * 1
        fertility = n * tf.nn.sigmoid(
            self.fertility_layer(attention_mechanism.values))

        # coverage shape: batch * atten_len * 1
        expand_coverage = tf.expand_dims(previous_coverage[i], axis=-1)
        pre_coverage = self.coverage_layer(expand_coverage / fertility)

        attention, alignments = _compute_attention(
            attention_mechanism, cell_output, pre_coverage,
            self._attention_layers[i] if self._attention_layers else None)
        alignment_history = previous_alignment_history[i].write(
            state.time, alignments) if self._alignment_history else ()

        # batch * atten_len * 1
        coverage = expand_coverage + tf.expand_dims(alignments, axis=-1)

        all_alignments.append(alignments)
        all_histories.append(alignment_history)
        all_attentions.append(attention)
        # Drop the trailing axis added above before storing the coverage.
        all_coverages.append(tf.squeeze(coverage, axis=[2]))

    attention = tf.concat(all_attentions, 1)
    next_state = CoverageAttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        coverages=self._item_or_tuple(all_coverages),
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(all_histories))

    if self._output_attention:
        return attention, next_state
    else:
        return cell_output, next_state
def __init__(self, config, batch_size, embedding, encoder_input, input_len,
             is_training=True, ru=False):
    """Build the bidirectional RNN encoder and its latent-variable heads.

    Args:
      config: Configuration object.  This method reads `RNN_CELL`,
        `ENC_RNN_SIZE`, `DROPOUT_KEEP`, `ENC_FUNC` and
        `LATENT_VARIABLE_SIZE` from it.
      batch_size: First dimension of the sampled noise tensor.
      embedding: Embedding parameters passed to `tf.nn.embedding_lookup`.
      encoder_input: Indices looked up in `embedding`.
      input_len: Passed to the attention mechanisms as the memory sequence
        length (only used by the `attn*` encoder functions).
      is_training: When true, dropout wraps the RNN cells and the latent
        variables are sampled; otherwise the latent variables reduce to
        the mean term.
      ru: Not used by this method.

    Raises:
      ValueError: If `config.RNN_CELL` or `config.ENC_FUNC` is not one of
        the supported values.
    """
    self.config = config
    with tf.variable_scope("encoder_input"):
        self.embedding = embedding
        self.encoder_input = encoder_input
        self.input_len = input_len
        self.batch_size = batch_size
        self.is_training = is_training

    # NOTE(review): confirm that the pooling/attention ops below are meant
    # to live inside the "encoder_rnn" variable scope.
    with tf.variable_scope("encoder_rnn"):
        encoder_emb_inputs = tf.nn.embedding_lookup(self.embedding,
                                                    self.encoder_input)

        def create_cell():
            """Return one RNN cell of the configured type."""
            if self.config.RNN_CELL == 'lnlstm':
                cell = rnn.LayerNormBasicLSTMCell(self.config.ENC_RNN_SIZE)
            elif self.config.RNN_CELL == 'lstm':
                cell = rnn.BasicLSTMCell(self.config.ENC_RNN_SIZE)
            elif self.config.RNN_CELL == 'gru':
                cell = rnn.GRUCell(self.config.ENC_RNN_SIZE)
            else:
                # Fail here instead of hitting an unbound `cell` below.
                message = 'rnn_cell {} not supported'.format(
                    self.config.RNN_CELL)
                logger.error(message)
                raise ValueError(message)
            if self.is_training:
                cell = tf.nn.rnn_cell.DropoutWrapper(
                    cell, output_keep_prob=self.config.DROPOUT_KEEP)
            return cell

        cell_fw = create_cell()
        cell_bw = create_cell()

        # NOTE(review): no sequence_length is passed here, so every time
        # step of the input is run through the RNN -- confirm intended.
        output = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw,
                                                 encoder_emb_inputs,
                                                 dtype=tf.float32)
        encoder_outputs, encoder_state = output

        def get_last_hidden():
            """Concatenate the final forward and backward hidden states."""
            if self.config.RNN_CELL == 'gru':
                return tf.concat([encoder_state[0], encoder_state[1]], -1)
            else:
                # For the LSTM variants index 1 of each state is used.
                return tf.concat([encoder_state[0][1],
                                  encoder_state[1][1]], -1)

        # Reduce the encoder outputs to one vector per example.
        if self.config.ENC_FUNC == 'mean':
            encoder_rnn_output = tf.reduce_mean(
                tf.concat(encoder_outputs, -1), 1)
        elif self.config.ENC_FUNC == 'lasth':
            # last fw and bw hidden state
            encoder_rnn_output = get_last_hidden()
        elif self.config.ENC_FUNC in ['attn', 'attn_scale']:
            attn = seq2seq.LuongAttention(
                self.config.ENC_RNN_SIZE * 2,
                tf.concat(encoder_outputs, -1),
                self.input_len,
                scale=self.config.ENC_FUNC == 'attn_scale')
            encoder_rnn_output, _ = _compute_attention(
                attn, get_last_hidden(), None, None)
        elif self.config.ENC_FUNC in ['attn_ba', 'attn_ba_norm']:
            attn = seq2seq.BahdanauAttention(
                self.config.ENC_RNN_SIZE,
                tf.concat(encoder_outputs, -1),
                self.input_len,
                normalize=self.config.ENC_FUNC == 'attn_ba_norm')
            encoder_rnn_output, _ = _compute_attention(
                attn, get_last_hidden(), None, None)
        else:
            # Fail here instead of hitting an unbound `encoder_rnn_output`.
            message = 'enc_func {} not supported'.format(
                self.config.ENC_FUNC)
            logger.error(message)
            raise ValueError(message)

    with tf.name_scope("mu"):
        mu = tf.layers.dense(encoder_rnn_output,
                             self.config.LATENT_VARIABLE_SIZE,
                             activation=tf.nn.tanh)
        self.mu = tf.layers.dense(mu, self.config.LATENT_VARIABLE_SIZE,
                                  activation=None)
    with tf.name_scope("log_var"):
        logvar = tf.layers.dense(encoder_rnn_output,
                                 self.config.LATENT_VARIABLE_SIZE,
                                 activation=tf.nn.tanh)
        self.logvar = tf.layers.dense(logvar,
                                      self.config.LATENT_VARIABLE_SIZE,
                                      activation=None)
    with tf.name_scope("epsilon"):
        epsilon = tf.random_normal(
            (self.batch_size, self.config.LATENT_VARIABLE_SIZE),
            mean=0.0, stddev=1.0)
    with tf.name_scope("latent_variables"):
        if self.is_training:
            # Sample: mu + exp(0.5 * logvar) * epsilon.
            self.latent_variables = self.mu + (
                tf.exp(0.5 * self.logvar) * epsilon)
        else:
            # The noise term is multiplied by 0, leaving only the mean term.
            self.latent_variables = self.mu + (
                tf.exp(0.5 * self.logvar) * 0)