def call(self, inputs, state):
    """Perform one step of the attention-wrapped RNN, attending first.

    NOTE: this wrapper changes the order of the attention operation compared
    with the stock ``AttentionWrapper.call``; mind the difference. Here each
    attention mechanism is queried with the *previous* cell state, the
    resulting attention is mixed into the cell input via
    ``self._cell_input_fn``, and only then is the wrapped cell stepped.
    Some of this code follows https://github.com/bgshih/aster/issues/2.

    Args:
        inputs: (possibly nested) input tensor(s) for this time step.
        state: ``seq2seq.AttentionWrapperState`` from the previous step.

    Returns:
        ``(output, next_state)``. ``output`` is the concatenated attention if
        ``self._output_attention`` is true, otherwise the wrapped cell's
        output. ``next_state`` is the new ``AttentionWrapperState``.

    Raises:
        TypeError: if ``state`` is not an ``AttentionWrapperState``.
    """
    if not isinstance(state, seq2seq.AttentionWrapperState):
        raise TypeError(
            "Expected state to be instance of AttentionWrapperState. "
            "Received type %s instead." % type(state))
    if self._is_multi:
        previous_alignments = state.alignments
        previous_alignment_history = state.alignment_history
    else:
        previous_alignments = [state.alignments]
        previous_alignment_history = [state.alignment_history]

    # The attention query is the previous cell state. It is identical for
    # every mechanism, so extract it once instead of inside the loop.
    # For LSTM-style (c, h) states only the hidden state h is used. Checking
    # the state itself (not just the cell type) also covers an LSTMCell that
    # sits inside a wrapper such as DropoutWrapper. Before this check, such a
    # (c, h) tuple was handed to the attention mechanism as the query.
    rnn_cell_state = state.cell_state
    if isinstance(self._cell, rnn.LSTMCell) or (
            isinstance(rnn_cell_state, tuple) and hasattr(rnn_cell_state, "h")):
        rnn_cell_state = rnn_cell_state.h

    all_alignments = []
    all_attentions = []
    all_histories = []
    all_attentions_state = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
        attention, alignments, attention_state = _compute_attention(
            attention_mechanism, rnn_cell_state, previous_alignments[i],
            self._attention_layers[i] if self._attention_layers else None)
        alignment_history = previous_alignment_history[i].write(
            state.time, alignments) if self._alignment_history else ()

        all_alignments.append(alignments)
        all_histories.append(alignment_history)
        all_attentions.append(attention)
        all_attentions_state.append(attention_state)

    attention = array_ops.concat(all_attentions, 1)
    # Unlike the stock wrapper, the cell input uses the attention computed in
    # *this* step rather than state.attention from the previous step.
    cell_inputs = self._cell_input_fn(inputs, attention)
    cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

    next_state = seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(all_histories),
        attention_state=self._item_or_tuple(all_attentions_state))

    if self._output_attention:
        return attention, next_state
    else:
        return cell_output, next_state
def get_zero_state(self):
    """Return the decoder's initial ``AttentionWrapperState``.

    The decoder cell starts from its zero state. The Bahdanau mechanism is
    queried once with that zero state (and its initial alignments), so the
    first decoding step already has an attention context.

    Returns:
        ``seq2seq.AttentionWrapperState`` with ``time`` set to an int32
        scalar zero, an empty ``alignment_history``, and ``attention`` set
        to the initial context vector.
    """
    inputs = self.inputs
    batch_size = tf.shape(inputs)[0]
    dtype = inputs.dtype

    cell_state = self.decoder.zero_state(batch_size, dtype)
    mechanism = self.bahdanau
    alignments = mechanism(
        cell_state, mechanism.initial_alignments(batch_size, dtype))

    # Context = alignment-weighted combination of the memory values, done as
    # a batched matmul of the alignments (expanded to a single row) with the
    # values, then squeezing that row axis away again.
    # NOTE(review): assumes alignments are (batch, max_time) and values are
    # (batch, max_time, depth) -- confirm against the mechanism used.
    weighted = tf.expand_dims(alignments, 1) @ mechanism.values
    context = tf.squeeze(weighted, [1])

    return seq2seq.AttentionWrapperState(
        cell_state=cell_state,
        attention=context,
        time=tf.zeros([], dtype=tf.int32),
        alignments=alignments,
        alignment_history=())
def decoding_layer(dec_embed_input, embeddings, enc_output, enc_state,
                   vocab_size, text_length, summary_length, max_summary_length,
                   rnn_size, vocab_to_int, keep_prob, batch_size, num_layers):
    """Build the attention decoder and return training and inference logits.

    Creates a dropout-wrapped LSTM decoder cell and wraps it with Bahdanau
    attention over ``enc_output``. It then builds a training decoder in
    variable scope ``"decode"`` and an inference decoder in the same scope
    with ``reuse=True``, so the two share variables.

    Args:
        dec_embed_input: decoder inputs for the training decoder (presumably
            embedded target tokens).
        embeddings: embedding matrix passed to the inference decoder.
        enc_output: encoder outputs, used as the attention memory.
        enc_state: encoder state; ``enc_state[0]`` seeds the decoder cell.
        vocab_size: number of units of the output projection.
        text_length: memory (source) sequence lengths for the attention.
        summary_length: target lengths passed to the training decoder.
        max_summary_length: maximum length passed to both decoders.
        rnn_size: LSTM units, attention depth and attention layer size.
        vocab_to_int: vocabulary mapping; must contain '<GO>' and '<EOS>'.
        keep_prob: input keep probability of the dropout wrapper.
        batch_size: batch size of the initial state and inference decoder.
        num_layers: number of loop iterations below (see the NOTE there).

    Returns:
        ``(training_logits, inference_logits)`` as returned by
        ``training_decoding_layer`` and ``inference_decoding_layer``.
    """
    # NOTE(review): each iteration overwrites ``dec_cell``, so only the cell
    # built in the last iteration is used; num_layers > 1 does not stack
    # layers. Stacking would need a MultiRNNCell and a per-layer initial cell
    # state, which changes the model, so it is left as is.
    for layer in range(num_layers):
        with tf.variable_scope('decoder_{}'.format(layer)):
            lstm = tf.nn.rnn_cell.LSTMCell(
                rnn_size,
                initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
            dec_cell = tf.nn.rnn_cell.DropoutWrapper(
                lstm, input_keep_prob=keep_prob)

    # Fully connected output projection onto the vocabulary.
    output_layer = Dense(
        vocab_size,
        kernel_initializer=tf.truncated_normal_initializer(
            mean=0.0, stddev=0.1))

    # Add the attention mechanism.
    attn_mech = seq.BahdanauAttention(
        rnn_size, enc_output, text_length,
        normalize=False, name='BahdanauAttention')
    dec_cell = seq.AttentionWrapper(
        cell=dec_cell, attention_mechanism=attn_mech,
        attention_layer_size=rnn_size)

    # The AttentionWrapperState of this API has more fields than
    # (cell_state, attention), so constructing it from two positional values
    # raises TypeError. Instead, start from the wrapper's zero state (zero
    # attention of size attention_layer_size, time 0, initial alignments) and
    # swap in the encoder state as the cell state.
    initial_state = dec_cell.zero_state(
        batch_size=batch_size, dtype=tf.float32).clone(cell_state=enc_state[0])

    with tf.variable_scope("decode"):
        training_logits = training_decoding_layer(
            dec_embed_input, summary_length, dec_cell, initial_state,
            output_layer, vocab_size, max_summary_length)
    with tf.variable_scope("decode", reuse=True):
        inference_logits = inference_decoding_layer(
            embeddings, vocab_to_int['<GO>'], vocab_to_int['<EOS>'],
            dec_cell, initial_state, output_layer, max_summary_length,
            batch_size)
    return training_logits, inference_logits
def call(self, inputs, state):
    """Perform a step of attention-wrapped RNN, with optional coverage.

    - Step 1: Mix the `inputs` and previous step's `attention` output via
      `cell_input_fn`.
    - Step 2: Call the wrapped `cell` with this input and its previous state.
    - Step 3: Score the cell's output with `attention_mechanism`.
    - Step 4: Calculate the alignments by passing the score through the
      `normalizer`.
    - Step 5: Calculate the context vector as the inner product between the
      alignments and the attention_mechanism's values (memory).
    - Step 6: Calculate the attention output by concatenating the cell output
      and context through the attention layer (a linear layer with
      `attention_layer_size` outputs).

    Coverage: when ``self.coverage`` is true, the value handed to each
    attention mechanism in place of its previous alignments is a coverage
    vector. That vector is the sum of all alignments written so far to that
    mechanism's alignment history, or zeros while the history is still
    empty. This reads the history TensorArray (``size()`` / ``stack()``), so
    alignment history must be enabled; otherwise the history entry is ``()``.

    Args:
      inputs: (Possibly nested tuple of) Tensor, the input at this time step.
      state: An instance of `AttentionWrapperState` containing
        tensors from the previous time step.

    Returns:
      A tuple `(attention_or_cell_output, next_state)`, where:

      - `attention_or_cell_output` depending on `output_attention`.
      - `next_state` is an instance of `AttentionWrapperState`
         containing the state calculated at this time step. Its
         `attention_state` is carried over unchanged from `state`.

    Raises:
      TypeError: If `state` is not an instance of `AttentionWrapperState`.
    """
    if not isinstance(state, tf.contrib.seq2seq.AttentionWrapperState):
        raise TypeError(
            "Expected state to be instance of AttentionWrapperState. "
            "Received type %s instead." % type(state))

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value.
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

    cell_batch_size = (
        cell_output.shape[0].value or tf.shape(cell_output)[0])
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    with tf.control_dependencies(
            self._batch_size_checks(cell_batch_size, error_message)):
        cell_output = tf.identity(cell_output, name="checked_cell_output")

    # Always work on a fresh list: in multi-mechanism mode state.alignments
    # is a tuple, and the coverage branch below assigns into it by index.
    if self._is_multi:
        previous_alignments = list(state.alignments)
        previous_alignment_history = state.alignment_history
    else:
        previous_alignments = [state.alignments]
        previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_histories = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
        if self.coverage:
            # In coverage mode the "previous alignments" are the coverage
            # vector. The stacked alignment history is laid out as
            # (decoder time, batch, attention length); transpose to put time
            # last and sum over it. On the first step (empty history) use
            # zeros shaped like the previous alignments.
            # The loop values are bound to locals because tf.cond calls both
            # lambdas right here, while building the graph.
            history = previous_alignment_history[i]
            prior_alignments = previous_alignments[i]
            previous_alignments[i] = tf.cond(
                history.size() > 0,
                lambda: tf.reduce_sum(
                    tf.transpose(history.stack(), [1, 2, 0]), axis=2),
                lambda: tf.zeros_like(prior_alignments))

        attention, alignments, _ = _compute_attention(
            attention_mechanism, cell_output, previous_alignments[i],
            self._attention_layers[i] if self._attention_layers else None)
        alignment_history = previous_alignment_history[i].write(
            state.time, alignments) if self._alignment_history else ()

        all_alignments.append(alignments)
        all_histories.append(alignment_history)
        all_attentions.append(attention)

    attention = tf.concat(all_attentions, 1)
    next_state = seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(all_histories),
        attention_state=state.attention_state)

    if self._output_attention:
        return attention, next_state
    else:
        return cell_output, next_state
def call(self, inputs, state):
    """Run one step of the multi-attention RNN and pool the attentions.

    Order of operations:
      1. Combine ``inputs`` with the previous step's attention through
         ``self._cell_input_fn``.
      2. Advance the wrapped cell with that combined input.
      3-5. Query every attention mechanism with the batch-size-checked cell
         output to obtain its alignments and attention.
      6. POOL the per-mechanism attention outputs into a single attention
         tensor. Only ``'avgpool'`` (an element-wise mean over mechanisms)
         is supported.

    Args:
      inputs: input for this time step.
      state: ``AttentionWrapperState`` from the previous time step.

    Returns:
      ``(output, next_state)``, where ``output`` is the pooled attention when
      ``self._output_attention`` is set and the cell output otherwise.

    Raises:
      TypeError: ``state`` is not an ``AttentionWrapperState``.
      ValueError: ``self.pooling`` is not ``'avgpool'``. This is raised only
        after the attention ops have been built.
    """
    if not isinstance(state, seq2seq.AttentionWrapperState):
        raise TypeError("Expected state to be instance of "
                        "AttentionWrapperState. "
                        "Received type %s instead." % type(state))

    # Steps 1-2: feed the previous attention into the cell and step it.
    cell_output, next_cell_state = self._cell(
        self._cell_input_fn(inputs, state.attention), state.cell_state)

    # Guard: memory and query batch sizes must agree (e.g. beam search needs
    # tiled memory). Prefer the static size, fall back to the dynamic one.
    static_batch = cell_output.shape[0].value
    cell_batch_size = static_batch or array_ops.shape(cell_output)[0]
    header = "When applying AttentionWrapper %s: " % self.name
    error_message = header + (
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiple=beam_width.")
    batch_checks = self._batch_size_checks(cell_batch_size, error_message)
    with ops.control_dependencies(batch_checks):
        cell_output = array_ops.identity(
            cell_output, name="checked_cell_output")

    prior_alignments = (
        state.alignments if self._is_multi else [state.alignments])
    prior_histories = (
        state.alignment_history if self._is_multi
        else [state.alignment_history])

    # Steps 3-5, once per mechanism. Attention-layer variables are reused
    # for every mechanism after the first.
    per_mechanism = []
    for idx, mechanism in enumerate(self._attention_mechanisms):
        layer = self._attention_layers[idx] if self._attention_layers else None
        context, align = _compute_attention(
            mechanism, cell_output, prior_alignments[idx], layer,
            reuse=idx > 0)
        if self._alignment_history:
            history = prior_histories[idx].write(state.time, align)
        else:
            history = ()
        per_mechanism.append((align, history, context))

    alignments_out = [entry[0] for entry in per_mechanism]
    histories_out = [entry[1] for entry in per_mechanism]
    contexts = [entry[2] for entry in per_mechanism]

    # Step 6: pool the attentions across mechanisms.
    if self.pooling != 'avgpool':
        raise ValueError('Unknown pooling method')
    attention = tf.reduce_mean(tf.stack(contexts, axis=1), axis=1)

    next_state = seq2seq.AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        alignments=self._item_or_tuple(alignments_out),
        alignment_history=self._item_or_tuple(histories_out))
    output = attention if self._output_attention else cell_output
    return output, next_state
def _init_encoder(self):
    """Build the attentive RNN encoder graph.

    Steps:
      1. Build the per-layer encoder cells with ``build_rnn_layers``.
      2. Optionally run the first (skip-LSTM) layer as its own
         ``tf.nn.dynamic_rnn`` (``separate_skip_rnn``).
      3. Wrap the top layer with attention over ``self._attended_memory``
         via ``add_attention``.
      4. Unroll the stack with ``tf.nn.dynamic_rnn``.

    Sets ``self._encoder_cells``, ``self._encoder_outputs`` and
    ``self._encoder_final_state``. For skip cells it also sets
    ``self.skip_infos``. When ``write_attention_alignment`` is ``True`` it
    also sets ``self.attention_summary`` and ``self.attention_alignment``.

    NOTE(review): only ``encoder_type == 'unidirectional'`` builds anything
    here; for any other value the encoder attributes are left unset.
    """

    def _skip_infos(updated_states):
        # Skip-RNN summaries built from ``updated_states`` (presumably a
        # per-step update mask -- confirm). The budget loss charges
        # ``cost_per_sample`` per update. Both values are summed over axis 1
        # and then averaged over axis 0.
        cost_per_sample = self._hparams.cost_per_sample[1]
        budget_loss = tf.reduce_mean(
            tf.reduce_sum(cost_per_sample * updated_states, 1), 0)
        mean_updates = tf.reduce_mean(tf.reduce_sum(updated_states, 1), 0)
        return SkipInfoTuple(updated_states, mean_updates, budget_loss)

    with tf.variable_scope("Encoder") as scope:
        encoder_inputs = self._maybe_add_dense_layers()
        cell_type = self._hparams.cell_type[
            0 if self._data_type == 'video' else 1]
        batch_size = self._hparams.batch_size[
            0 if self._mode == 'train' else 1]

        if self._hparams.encoder_type == 'unidirectional':
            self._encoder_cells, initial_state = build_rnn_layers(
                cell_type=cell_type,
                num_units_per_layer=self._num_units_per_layer,
                use_dropout=self._hparams.use_dropout,
                dropout_probability=self._dropout_probability,
                mode=self._mode,
                as_list=True,
                batch_size=batch_size,
                dtype=self._hparams.dtype)
            self._encoder_cells = maybe_list(self._encoder_cells)
            print(self._num_units_per_layer)
            print('encoder_cells', self._encoder_cells)
            print('AttentiveEncoder_initial_state', initial_state)

            if cell_type == 'skip_lstm' and self._hparams.separate_skip_rnn:
                # Run the first (skip) layer on its own. The remaining layers
                # consume its outputs ``h`` in the main dynamic_rnn below.
                # NOTE(review): ``initial_state`` is not trimmed when the
                # skip cell is removed here -- confirm it still matches the
                # remaining layers.
                skip_cell = self._encoder_cells[0]
                self._encoder_cells = self._encoder_cells[1:]
                skip_out = tf.nn.dynamic_rnn(
                    cell=skip_cell,
                    inputs=encoder_inputs,
                    sequence_length=self._inputs_len,
                    parallel_iterations=batch_size,
                    swap_memory=False,
                    dtype=self._hparams.dtype,
                    scope=scope,
                    initial_state=skip_cell.trainable_initial_state(
                        batch_size),
                )
                print('skip_out', skip_out)
                skip_output, _ = skip_out
                h, updated_states = skip_output
                print('skip_encoder_updated_states', updated_states)
                self.skip_infos = _skip_infos(updated_states)
                print('skip_encoder_layer_h', h)

            print('attended_memory', self._attended_memory)
            print(self._encoder_cells)
            # Wrap the top layer with attention over the attended memory.
            attention_cells, dummy_initial_state = add_attention(
                cell_type=('lstm' if self._hparams.separate_skip_rnn
                           else cell_type),
                cells=self._encoder_cells[-1],
                attention_types=self._hparams.attention_type[0],
                num_units=self._num_units_per_layer[-1],
                memory=self._attended_memory,
                memory_len=self._attended_memory_length,
                mode=self._mode,
                dtype=self._hparams.dtype,
                initial_state=initial_state[-1] if (
                    isinstance(initial_state, tuple)
                    and not isinstance(initial_state, SkipLSTMStateTuple))
                else initial_state,
                # NOTE(review): tf.shape returns a shape vector, not a
                # scalar -- confirm add_attention expects that.
                batch_size=tf.shape(self._inputs_len),
                write_attention_alignment=self._hparams.
                write_attention_alignment,
                fusion_type='linear_fusion',
            )
            print('AttentiveEncoder_initial_state2', initial_state)

            # Replace the top layer's initial state with the one returned
            # for the attention-wrapped cell.
            if isinstance(initial_state, tuple) and not isinstance(
                    initial_state, SkipLSTMStateTuple):
                initial_state = list(initial_state)
                initial_state[-1] = dummy_initial_state
                initial_state = tuple(initial_state)
            else:
                initial_state = dummy_initial_state
            self._encoder_cells[-1] = attention_cells

            print('AttentiveEncoder_encoder_cells', self._encoder_cells)
            self._encoder_cells = maybe_multirnn(self._encoder_cells)
            print('AttentiveEncoder_encoder_cells', self._encoder_cells)
            print('AttentiveEncoder_encoder_inputs', encoder_inputs)
            print('AttentiveEncoder_inputs_len', self._inputs_len)

            out = tf.nn.dynamic_rnn(
                cell=self._encoder_cells,
                inputs=(encoder_inputs
                        if not self._hparams.separate_skip_rnn else h),
                sequence_length=self._inputs_len,
                parallel_iterations=batch_size,
                swap_memory=False,
                dtype=self._hparams.dtype,
                scope=scope,
                initial_state=initial_state,
            )
            self._encoder_outputs, self._encoder_final_state = out
            print("AttentiveEncoder_dynamic_rnn_out", self._encoder_outputs)
            print("AttentiveEncoder_dynamic_rnn_fs", self._encoder_final_state)

            if not self._hparams.separate_skip_rnn and 'skip' in cell_type:
                # Skip cells emit (outputs, updated_states).
                self._encoder_outputs, updated_states = self._encoder_outputs
                self.skip_infos = _skip_infos(updated_states)

            # For a stacked (tuple) final state keep only the last element.
            if isinstance(self._encoder_final_state, tuple) and not isinstance(
                    self._encoder_final_state, seq2seq.AttentionWrapperState):
                self._encoder_final_state = self._encoder_final_state[-1]
            print('AttentiveEncoder_final_state_inBetween',
                  self._encoder_final_state)

            if isinstance(self._encoder_final_state,
                          seq2seq.AttentionWrapperState):
                # Rebuild the cell state as plain LSTMStateTuple(c, h): a list
                # of them for a sequence of states, or a single one otherwise.
                cell_state = self._encoder_final_state.cell_state
                try:
                    cell_state = [
                        LSTMStateTuple(cs.c, cs.h) for cs in cell_state
                    ]
                except (AttributeError, TypeError):
                    # Elements have no .c/.h (cell_state is itself a single
                    # (c, h)-style state), or cell_state is not iterable.
                    cell_state = LSTMStateTuple(cell_state.c, cell_state.h)
                self._encoder_final_state = seq2seq.AttentionWrapperState(
                    cell_state,
                    self._encoder_final_state.attention,
                    self._encoder_final_state.time,
                    self._encoder_final_state.alignments,
                    self._encoder_final_state.alignment_history,
                    self._encoder_final_state.attention_state)
            print('AttentiveEncoder_final_state', self._encoder_final_state)

            if self._hparams.write_attention_alignment is True:
                self.attention_summary, self.attention_alignment = (
                    self._create_attention_alignments_summary(
                        maybe_list(self._encoder_final_state)[-1]))