# Earlier variant of zero_state: the AttentionWrapperState it builds has no
# attention_state field (compare with the documented version below).
def zero_state(self, batch_size, dtype):
  with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
    if self._initial_cell_state is not None:
      cell_state = self._initial_cell_state
    else:
      cell_state = self._cell.zero_state(batch_size, dtype)
    error_message = (
        "When calling zero_state of AttentionWrapper %s: " % self._base_name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the requested batch size. Are you using "
        "the BeamSearchDecoder? If so, make sure your encoder output has "
        "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
        "the batch_size= argument passed to zero_state is "
        "batch_size * beam_width.")
    with tf.control_dependencies(
        self._batch_size_checks(batch_size, error_message)):
      cell_state = nest.map_structure(
          lambda s: tf.identity(s, name="checked_cell_state"),
          cell_state)
    return AttentionWrapperState(
        cell_state=cell_state,
        time=tf.zeros([], dtype=tf.int32),
        attention=_zero_state_tensors(self._attention_layer_size,
                                      batch_size, dtype),
        alignments=self._item_or_tuple(
            attention_mechanism.initial_alignments(batch_size, dtype)
            for attention_mechanism in self._attention_mechanisms),
        alignment_history=self._item_or_tuple(
            tf.TensorArray(dtype=dtype, size=0, dynamic_size=True)
            if self._alignment_history else ()
            for _ in self._attention_mechanisms))
# Earlier variant of state_size: no attention_state entry.
def state_size(self):
  return AttentionWrapperState(
      cell_state=self._cell.state_size,
      time=tf.TensorShape([]),
      attention=self._attention_layer_size,
      alignments=self._item_or_tuple(
          a.alignments_size for a in self._attention_mechanisms),
      alignment_history=self._item_or_tuple(
          () for _ in self._attention_mechanisms))  # sometimes a TensorArray
def zero_state(self, batch_size, dtype): """Return an initial (zero) state tuple for this `AttentionWrapper`. **NOTE** Please see the initializer documentation for details of how to call `zero_state` if using an `AttentionWrapper` with a `BeamSearchDecoder`. Args: batch_size: `0D` integer tensor: the batch size. dtype: The internal state data type. Returns: An `AttentionWrapperState` tuple containing zeroed out tensors and, possibly, empty `TensorArray` objects. Raises: ValueError: (or, possibly at runtime, InvalidArgument), if `batch_size` does not match the output size of the encoder passed to the wrapper object at initialization time. """ with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): if self._initial_cell_state is not None: cell_state = self._initial_cell_state else: cell_state = self._cell.zero_state(batch_size, dtype) error_message = ( "When calling zero_state of AttentionWrapper %s: " % self._base_name + "Non-matching batch sizes between the memory " "(encoder output) and the requested batch size. Are you using " "the BeamSearchDecoder? If so, make sure your encoder output has " "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and " "the batch_size= argument passed to zero_state is " "batch_size * beam_width.") with tf.control_dependencies( self._batch_size_checks(batch_size, error_message)): cell_state = nest.map_structure( lambda s: tf.identity(s, name="checked_cell_state"), cell_state) initial_alignments = [ attention_mechanism.initial_alignments(batch_size, dtype) for attention_mechanism in self._attention_mechanisms ] return AttentionWrapperState( cell_state=cell_state, time=tf.zeros([], dtype=tf.int32), attention=_zero_state_tensors(self._attention_layer_size, batch_size, dtype), alignments=self._item_or_tuple(initial_alignments), attention_state=self._item_or_tuple( attention_mechanism.initial_state(batch_size, dtype) for attention_mechanism in self._attention_mechanisms), alignment_history=self._item_or_tuple( tf.TensorArray(dtype, size=0, dynamic_size=True, element_shape=alignment.shape) if self. _alignment_history else () for alignment in initial_alignments))
def state_size(self): """The `state_size` property of `AttentionWrapper`. Returns: An `AttentionWrapperState` tuple containing shapes used by this object. """ return AttentionWrapperState( cell_state=self._cell.state_size, time=tf.TensorShape([]), attention=self._attention_layer_size, alignments=self._item_or_tuple( a.alignments_size for a in self._attention_mechanisms), attention_state=self._item_or_tuple( a.state_size for a in self._attention_mechanisms), alignment_history=self._item_or_tuple( a.alignments_size if self._alignment_history else () for a in self._attention_mechanisms)) # sometimes a TensorArray
def call(self, inputs, state): """Perform a step of attention-wrapped RNN. - Step 1: Mix the `inputs` and previous step's `attention` output via `cell_input_fn`. - Step 2: Call the wrapped `cell` with this input and its previous state. - Step 3: Score the cell's output with `attention_mechanism`. - Step 4: Calculate the alignments by passing the score through the `normalizer`. - Step 5: Calculate the context vector as the inner product between the alignments and the attention_mechanism's values (memory). - Step 6: Calculate the attention output by concatenating the cell output and context through the attention layer (a linear layer with `attention_layer_size` outputs). Args: inputs: (Possibly nested tuple of) Tensor, the input at this time step. state: An instance of `AttentionWrapperState` containing tensors from the previous time step. Returns: A tuple `(attention_or_cell_output, next_state)`, where: - `attention_or_cell_output` depending on `output_attention`. - `next_state` is an instance of `AttentionWrapperState` containing the state calculated at this time step. Raises: TypeError: If `state` is not an instance of `AttentionWrapperState`. """ if not isinstance(state, AttentionWrapperState): raise TypeError( "Expected state to be instance of AttentionWrapperState. " "Received type %s instead." % type(state)) # Step 1: Calculate the true inputs to the cell based on the # previous attention value. cell_inputs = self._cell_input_fn(inputs, state.attention) cell_state = state.cell_state cell_output, next_cell_state = self._cell(cell_inputs, cell_state) cell_batch_size = (cell_output.shape[0].value or tf.shape(cell_output)[0]) error_message = ( "When applying AttentionWrapper %s: " % self.name + "Non-matching batch sizes between the memory " "(encoder output) and the query (decoder output). Are you using " "the BeamSearchDecoder? You may need to tile your memory input via " "the tf.contrib.seq2seq.tile_batch function with argument " "multiple=beam_width.") with tf.control_dependencies( self._batch_size_checks(cell_batch_size, error_message)): cell_output = tf.identity(cell_output, name="checked_cell_output") if self._is_multi: previous_attention_state = state.attention_state previous_alignment_history = state.alignment_history else: previous_attention_state = [state.attention_state] previous_alignment_history = [state.alignment_history] all_alignments = [] all_attentions = [] all_attention_states = [] maybe_all_histories = [] for i, attention_mechanism in enumerate(self._attention_mechanisms): # Note: This is the only modification hacked into the attention wrapper to support # monotonic Luong attention. 
attention_mechanism.time = state.time attention, alignments, next_attention_state = _luong_local_compute_attention( attention_mechanism, cell_output, previous_attention_state[i], self._attention_layers[i] if self._attention_layers else None) alignment_history = previous_alignment_history[i].write( state.time, alignments) if self._alignment_history else () all_attention_states.append(next_attention_state) all_alignments.append(alignments) all_attentions.append(attention) maybe_all_histories.append(alignment_history) attention = tf.concat(all_attentions, 1) next_state = AttentionWrapperState( time=state.time + 1, cell_state=next_cell_state, attention=attention, attention_state=self._item_or_tuple(all_attention_states), alignments=self._item_or_tuple(all_alignments), alignment_history=self._item_or_tuple(maybe_all_histories)) if self._output_attention: return attention, next_state else: return cell_output, next_state
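# A sketch (an assumption, not the original helper) of what
# `_luong_local_compute_attention` above is expected to return. It mirrors
# the stock `_compute_attention` from tf.contrib.seq2seq: the mechanism
# produces alignments, the context is the alignment-weighted sum of the
# memory, and the optional attention layer mixes cell output and context.
# Any local-window masking around the current position is assumed to live
# inside the mechanism itself, which is why `call` assigns
# `attention_mechanism.time = state.time` before invoking it.
def _luong_local_compute_attention(attention_mechanism, cell_output,
                                   attention_state, attention_layer):
  alignments, next_attention_state = attention_mechanism(
      cell_output, state=attention_state)
  # Expand to [batch_size, 1, memory_time] so the matmul below contracts
  # over memory_time.
  expanded_alignments = tf.expand_dims(alignments, 1)
  # [batch_size, 1, memory_time] x [batch_size, memory_time, num_units]
  # -> [batch_size, 1, num_units]; squeeze out the singleton dim.
  context = tf.squeeze(
      tf.matmul(expanded_alignments, attention_mechanism.values), [1])
  if attention_layer is not None:
    attention = attention_layer(tf.concat([cell_output, context], 1))
  else:
    attention = context
  return attention, alignments, next_attention_state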
def call(self, inputs, state):
  # Step 1: Calculate the true inputs to the cell based on the
  # previous attention value.
  cell_inputs = self._cell_input_fn(inputs, state.attention)
  cell_state = state.cell_state
  cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

  cell_batch_size = (
      cell_output.shape[0].value or array_ops.shape(cell_output)[0])
  error_message = (
      "When applying AttentionWrapper %s: " % self.name +
      "Non-matching batch sizes between the memory "
      "(encoder output) and the query (decoder output). Are you using "
      "the BeamSearchDecoder? You may need to tile your memory input via "
      "the tf.contrib.seq2seq.tile_batch function with argument "
      "multiple=beam_width.")
  with ops.control_dependencies([
      check_ops.assert_equal(
          cell_batch_size,
          self._attention_mechanism.batch_size,
          message=error_message)]):
    cell_output = array_ops.identity(cell_output, name="checked_cell_output")

  multi_context = []
  multi_alignments = []
  # List of (batch_size, alignments_size) tensors, one per mechanism.
  prev_alignments = self._attention_mechanism.separate_alignments(
      state.alignments)
  # `izip` is itertools.izip (Python 2); on Python 3 use the built-in zip.
  for attention_mechanism, prev_a in izip(
      self._attention_mechanism.attention_mechanisms, prev_alignments):
    alignments = attention_mechanism(cell_output, previous_alignments=prev_a)

    # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time].
    expanded_alignments = array_ops.expand_dims(alignments, 1)
    # Context is the inner product of alignments and values along the
    # memory time dimension.
    # alignments shape is
    #   [batch_size, 1, memory_time]
    # attention_mechanism.values shape is
    #   [batch_size, memory_time, attention_mechanism.num_units]
    # the batched matmul is over memory_time, so the output shape is
    #   [batch_size, 1, attention_mechanism.num_units].
    # we then squeeze out the singleton dim.
    attention_mechanism_values = attention_mechanism.values
    context = math_ops.matmul(expanded_alignments,
                              attention_mechanism_values)
    context = array_ops.squeeze(context, [1])

    multi_context.append(context)
    multi_alignments.append(alignments)

  # Combine the per-mechanism contexts into a single context vector.
  context = tf.concat(multi_context, axis=1)
  with tf.variable_scope('CombineContext'):
    context = tf.layers.dense(context, self._multi_attention_size,
                              use_bias=False, activation=tf.nn.relu)
  # Combine alignments: shape (batch_size, \sum_m alignments_size_m).
  alignments = self._attention_mechanism.combine_alignments(multi_alignments)

  if self._attention_layer is not None:
    attention = self._attention_layer(
        array_ops.concat([cell_output, context], 1))
  else:
    attention = context

  if self._alignment_history:
    alignment_history = state.alignment_history.write(state.time, alignments)
  else:
    alignment_history = ()

  next_state = AttentionWrapperState(
      time=state.time + 1,
      cell_state=next_cell_state,
      attention=attention,
      alignments=alignments,
      alignment_history=alignment_history)

  if self._output_attention:
    return attention, next_state
  else:
    return cell_output, next_state
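# The `call` above relies on a multi-attention container object with three
# members not shown in this section: `attention_mechanisms`, and the
# `separate_alignments` / `combine_alignments` pair. Below is a minimal
# sketch of such a container, assuming statically known memory sizes; the
# class name `_MultiAttentionSketch` is hypothetical and the real container
# is defined elsewhere. Splitting and concatenating along axis 1 matches the
# per-mechanism (batch_size, alignments_size_m) shapes documented above.
class _MultiAttentionSketch(object):

  def __init__(self, attention_mechanisms):
    self.attention_mechanisms = attention_mechanisms
    # All mechanisms are assumed to share the same (tiled) batch size.
    self.batch_size = attention_mechanisms[0].batch_size

  def separate_alignments(self, alignments):
    # Split (batch_size, \sum_m alignments_size_m) back into the
    # per-mechanism (batch_size, alignments_size_m) pieces.
    sizes = [m.alignments_size for m in self.attention_mechanisms]
    return tf.split(alignments, sizes, axis=1)

  def combine_alignments(self, alignments_list):
    # Concatenate per-mechanism alignments along the memory-time axis.
    return tf.concat(alignments_list, axis=1)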