    def zero_state(self, batch_size, dtype):
        with tf.name_scope(type(self).__name__ + "ZeroState",
                           values=[batch_size]):
            if self._initial_cell_state is not None:
                cell_state = self._initial_cell_state
            else:
                cell_state = self._cell.zero_state(batch_size, dtype)
            error_message = (
                "When calling zero_state of AttentionWrapper %s: " %
                self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and the requested batch size.    Are you using "
                "the BeamSearchDecoder?    If so, make sure your encoder output has "
                "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
                "the batch_size= argument passed to zero_state is "
                "batch_size * beam_width.")
            with tf.control_dependencies(
                    self._batch_size_checks(batch_size, error_message)):
                cell_state = nest.map_structure(
                    lambda s: tf.identity(s, name="checked_cell_state"),
                    cell_state)

            return AttentionWrapperState(
                cell_state=cell_state,
                time=tf.zeros([], dtype=tf.int32),
                attention=_zero_state_tensors(self._attention_layer_size,
                                              batch_size, dtype),
                alignments=self._item_or_tuple(
                    attention_mechanism.initial_alignments(batch_size, dtype)
                    for attention_mechanism in self._attention_mechanisms),
                alignment_history=self._item_or_tuple(
                    tf.TensorArray(dtype=dtype, size=0, dynamic_size=True)
                    if self._alignment_history else ()
                    for _ in self._attention_mechanisms))

    def state_size(self):
        return AttentionWrapperState(
            cell_state=self._cell.state_size,
            time=tf.TensorShape([]),
            attention=self._attention_layer_size,
            alignments=self._item_or_tuple(
                a.alignments_size for a in self._attention_mechanisms),
            alignment_history=self._item_or_tuple(
                () for _ in self._attention_mechanisms))  # sometimes a TensorArray
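The "# sometimes a TensorArray" comment above refers to `alignment_history`: when the wrapper is built with `alignment_history=True`, this field accumulates one alignment vector per decoding step. Below is a minimal sketch of how that history is commonly read back after decoding with `tf.contrib.seq2seq.dynamic_decode`; `decoder` is a placeholder for a decoder built on top of this wrapped cell, not something defined in the examples here:

import tensorflow as tf

# final_state is the AttentionWrapperState left over after decoding finishes.
outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)

# With alignment_history=True the field is a TensorArray; stack() turns it into
# a dense tensor of shape [decoder_time, batch_size, memory_time].
alignment_history = final_state.alignment_history.stack()

# Transpose to [batch_size, decoder_time, memory_time] for inspection/plotting.
attention_images = tf.transpose(alignment_history, [1, 0, 2])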
    def zero_state(self, batch_size, dtype):
        """Return an initial (zero) state tuple for this `AttentionWrapper`.

        **NOTE** Please see the initializer documentation for details of how
        to call `zero_state` if using an `AttentionWrapper` with a
        `BeamSearchDecoder`.

        Args:
          batch_size: `0D` integer tensor: the batch size.
          dtype: The internal state data type.

        Returns:
          An `AttentionWrapperState` tuple containing zeroed out tensors and,
          possibly, empty `TensorArray` objects.

        Raises:
          ValueError: (or, possibly at runtime, InvalidArgument), if
            `batch_size` does not match the output size of the encoder passed
            to the wrapper object at initialization time.
        """
        with tf.name_scope(type(self).__name__ + "ZeroState",
                           values=[batch_size]):
            if self._initial_cell_state is not None:
                cell_state = self._initial_cell_state
            else:
                cell_state = self._cell.zero_state(batch_size, dtype)
            error_message = (
                "When calling zero_state of AttentionWrapper %s: " %
                self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and the requested batch size.  Are you using "
                "the BeamSearchDecoder?  If so, make sure your encoder output has "
                "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
                "the batch_size= argument passed to zero_state is "
                "batch_size * beam_width.")
            with tf.control_dependencies(
                    self._batch_size_checks(batch_size, error_message)):
                cell_state = nest.map_structure(
                    lambda s: tf.identity(s, name="checked_cell_state"),
                    cell_state)
            initial_alignments = [
                attention_mechanism.initial_alignments(batch_size, dtype)
                for attention_mechanism in self._attention_mechanisms
            ]
            return AttentionWrapperState(
                cell_state=cell_state,
                time=tf.zeros([], dtype=tf.int32),
                attention=_zero_state_tensors(self._attention_layer_size,
                                              batch_size, dtype),
                alignments=self._item_or_tuple(initial_alignments),
                attention_state=self._item_or_tuple(
                    attention_mechanism.initial_state(batch_size, dtype)
                    for attention_mechanism in self._attention_mechanisms),
                alignment_history=self._item_or_tuple(
                    tf.TensorArray(dtype,
                                   size=0,
                                   dynamic_size=True,
                                   element_shape=alignment.shape)
                    if self._alignment_history else ()
                    for alignment in initial_alignments))

    def state_size(self):
        """The `state_size` property of `AttentionWrapper`.

        Returns:
          An `AttentionWrapperState` tuple containing shapes used by this object.
        """
        return AttentionWrapperState(
            cell_state=self._cell.state_size,
            time=tf.TensorShape([]),
            attention=self._attention_layer_size,
            alignments=self._item_or_tuple(
                a.alignments_size for a in self._attention_mechanisms),
            attention_state=self._item_or_tuple(
                a.state_size for a in self._attention_mechanisms),
            alignment_history=self._item_or_tuple(
                a.alignments_size if self._alignment_history else ()
                for a in self._attention_mechanisms))  # sometimes a TensorArray
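The error message and the docstring NOTE above describe the beam-search setup in words. The sketch below shows that setup; it assumes TF 1.x with tf.contrib, and `encoder_outputs`, `encoder_final_state`, `source_lengths`, `decoder_cell`, `batch_size` and `beam_width` are placeholder names, not taken from the examples:

import tensorflow as tf

beam_width = 5

# Tile the memory (encoder output) so that every beam gets its own copy.
tiled_memory = tf.contrib.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
tiled_lengths = tf.contrib.seq2seq.tile_batch(source_lengths, multiplier=beam_width)
tiled_encoder_state = tf.contrib.seq2seq.tile_batch(encoder_final_state,
                                                    multiplier=beam_width)

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=256, memory=tiled_memory, memory_sequence_length=tiled_lengths)
attention_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism, attention_layer_size=256)

# zero_state must be asked for batch_size * beam_width entries; anything else
# trips the _batch_size_checks assertion shown above.
decoder_initial_state = attention_cell.zero_state(
    batch_size=batch_size * beam_width,
    dtype=tf.float32).clone(cell_state=tiled_encoder_state)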
Example #5
    def call(self, inputs, state):
        """Perform a step of attention-wrapped RNN.

        - Step 1: Mix the `inputs` and previous step's `attention` output via
          `cell_input_fn`.
        - Step 2: Call the wrapped `cell` with this input and its previous state.
        - Step 3: Score the cell's output with `attention_mechanism`.
        - Step 4: Calculate the alignments by passing the score through the
          `normalizer`.
        - Step 5: Calculate the context vector as the inner product between the
          alignments and the attention_mechanism's values (memory).
        - Step 6: Calculate the attention output by concatenating the cell output
          and context through the attention layer (a linear layer with
          `attention_layer_size` outputs).

        Args:
          inputs: (Possibly nested tuple of) Tensor, the input at this time step.
          state: An instance of `AttentionWrapperState` containing
            tensors from the previous time step.

        Returns:
          A tuple `(attention_or_cell_output, next_state)`, where:

          - `attention_or_cell_output` depending on `output_attention`.
          - `next_state` is an instance of `AttentionWrapperState`
             containing the state calculated at this time step.

        Raises:
          TypeError: If `state` is not an instance of `AttentionWrapperState`.
        """
        if not isinstance(state, AttentionWrapperState):
            raise TypeError(
                "Expected state to be instance of AttentionWrapperState. "
                "Received type %s instead." % type(state))

        # Step 1: Calculate the true inputs to the cell based on the
        # previous attention value.
        cell_inputs = self._cell_input_fn(inputs, state.attention)
        cell_state = state.cell_state
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        cell_batch_size = (cell_output.shape[0].value
                           or tf.shape(cell_output)[0])
        error_message = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output).  Are you using "
            "the BeamSearchDecoder?  You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with tf.control_dependencies(
                self._batch_size_checks(cell_batch_size, error_message)):
            cell_output = tf.identity(cell_output, name="checked_cell_output")

        if self._is_multi:
            previous_attention_state = state.attention_state
            previous_alignment_history = state.alignment_history
        else:
            previous_attention_state = [state.attention_state]
            previous_alignment_history = [state.alignment_history]

        all_alignments = []
        all_attentions = []
        all_attention_states = []
        maybe_all_histories = []
        for i, attention_mechanism in enumerate(self._attention_mechanisms):
            # Note: This is the only modification hacked into the attention wrapper to support
            # monotonic Luong attention.
            attention_mechanism.time = state.time

            attention, alignments, next_attention_state = _luong_local_compute_attention(
                attention_mechanism, cell_output, previous_attention_state[i],
                self._attention_layers[i] if self._attention_layers else None)
            alignment_history = previous_alignment_history[i].write(
                state.time, alignments) if self._alignment_history else ()

            all_attention_states.append(next_attention_state)
            all_alignments.append(alignments)
            all_attentions.append(attention)
            maybe_all_histories.append(alignment_history)

        attention = tf.concat(all_attentions, 1)
        next_state = AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            attention_state=self._item_or_tuple(all_attention_states),
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(maybe_all_histories))

        if self._output_attention:
            return attention, next_state
        else:
            return cell_output, next_state
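Steps 3 through 6 in the docstring above describe the per-mechanism attention computation. The definition of `_luong_local_compute_attention` is not shown here; the sketch below paraphrases the equivalent stock helper in tf.contrib.seq2seq, assuming the newer AttentionMechanism interface that takes a `state` argument (older versions take `previous_alignments`, as in Example #6 below):

import tensorflow as tf

def _compute_attention_sketch(attention_mechanism, cell_output,
                              attention_state, attention_layer):
    # Steps 3-4: score the query against the memory and normalize the scores
    # into alignments (plus the mechanism's next internal state).
    alignments, next_attention_state = attention_mechanism(
        cell_output, state=attention_state)

    # Step 5: context vector = alignments-weighted sum over the memory values,
    # computed as a batched matmul over the memory time dimension.
    expanded_alignments = tf.expand_dims(alignments, 1)
    context = tf.matmul(expanded_alignments, attention_mechanism.values)
    context = tf.squeeze(context, [1])

    # Step 6: optionally project [cell_output; context] down to the attention
    # output; otherwise the context itself is the attention.
    if attention_layer is not None:
        attention = attention_layer(tf.concat([cell_output, context], 1))
    else:
        attention = context

    return attention, alignments, next_attention_state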
Example #6
    def call(self, inputs, state):
        # Step 1: Calculate the true inputs to the cell based on the
        # previous attention value.
        cell_inputs = self._cell_input_fn(inputs, state.attention)
        cell_state = state.cell_state
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        cell_batch_size = (cell_output.shape[0].value
                           or array_ops.shape(cell_output)[0])
        error_message = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output).  Are you using "
            "the BeamSearchDecoder?  You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with ops.control_dependencies([
                check_ops.assert_equal(cell_batch_size,
                                       self._attention_mechanism.batch_size,
                                       message=error_message)
        ]):
            cell_output = array_ops.identity(cell_output,
                                             name="checked_cell_output")

        multi_context = []
        multi_alignments = []
        prev_alignments = self._attention_mechanism.separate_alignments(
            state.alignments)  # list of (batch_size, alignments_size)
        for attention_mechanism, prev_a in izip(
                self._attention_mechanism.attention_mechanisms,
                prev_alignments):
            alignments = attention_mechanism(cell_output,
                                             previous_alignments=prev_a)

            # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
            expanded_alignments = array_ops.expand_dims(alignments, 1)
            # Context is the inner product of alignments and values along the
            # memory time dimension.
            # alignments shape is
            #   [batch_size, 1, memory_time]
            # attention_mechanism.values shape is
            #   [batch_size, memory_time, attention_mechanism.num_units]
            # the batched matmul is over memory_time, so the output shape is
            #   [batch_size, 1, attention_mechanism.num_units].
            # we then squeeze out the singleton dim.
            attention_mechanism_values = attention_mechanism.values
            context = math_ops.matmul(expanded_alignments,
                                      attention_mechanism_values)
            context = array_ops.squeeze(context, [1])

            multi_context.append(context)
            multi_alignments.append(alignments)

        # Combine the contexts from all attention mechanisms into one vector
        context = tf.concat(multi_context, axis=1)
        with tf.variable_scope('CombineContext'):
            context = tf.layers.dense(context,
                                      self._multi_attention_size,
                                      use_bias=False,
                                      activation=tf.nn.relu)

        # Combine alignments
        alignments = self._attention_mechanism.combine_alignments(
            multi_alignments)  # (batch_size, \sum_{m} alignments_size_m)

        if self._attention_layer is not None:
            attention = self._attention_layer(
                array_ops.concat([cell_output, context], 1))
        else:
            attention = context

        if self._alignment_history:
            alignment_history = state.alignment_history.write(
                state.time, alignments)
        else:
            alignment_history = ()

        next_state = AttentionWrapperState(time=state.time + 1,
                                           cell_state=next_cell_state,
                                           attention=attention,
                                           alignments=alignments,
                                           alignment_history=alignment_history)

        if self._output_attention:
            return attention, next_state
        else:
            return cell_output, next_state
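The shape comments inside the loop above are the part that is easiest to get wrong. Here is a self-contained sketch with made-up sizes (batch_size=4, memory_time=7, num_units=16; none of these come from the example) showing how expand_dims, matmul and squeeze produce a [batch_size, num_units] context vector:

import tensorflow as tf

batch_size, memory_time, num_units = 4, 7, 16

# Fake alignments (each row sums to 1) and fake memory values.
alignments = tf.nn.softmax(tf.random_normal([batch_size, memory_time]))
values = tf.random_normal([batch_size, memory_time, num_units])

expanded_alignments = tf.expand_dims(alignments, 1)   # [4, 1, 7]
context = tf.matmul(expanded_alignments, values)      # [4, 1, 16]
context = tf.squeeze(context, [1])                    # [4, 16]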