Esempio n. 1
0
 def zero_state(self, batch_size, dtype):
     """Return an initial (zero) state tuple for this `AttentionWrapper`.

     **NOTE** Please see the initializer documentation for details of how
     to call `zero_state` if using an `AttentionWrapper` with a
     `BeamSearchDecoder`.

     Args:
       batch_size: `0D` integer tensor: the batch size.
       dtype: The internal state data type.

     Returns:
       An `AttentionWrapperState` tuple containing zeroed out tensors and,
       possibly, empty `TensorArray` objects.

     Raises:
       ValueError: (or, possibly at runtime, InvalidArgument), if
         `batch_size` does not match the output size of the encoder passed
         to the wrapper object at initialization time.
     """
     with ops.name_scope(type(self).__name__ + "ZeroState",
                         values=[batch_size]):
         # A user-supplied initial cell state takes precedence over the
         # wrapped cell's own zero state.
         if self._initial_cell_state is not None:
             cell_state = self._initial_cell_state
         else:
             cell_state = self._cell.zero_state(batch_size, dtype)
         error_message = (
             "When calling zero_state of AttentionWrapper %s: " %
             self._base_name +
             "Non-matching batch sizes between the memory "
             "(encoder output) and the requested batch size.  Are you using "
             "the BeamSearchDecoder?  If so, make sure your encoder output has "
             "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
             "the batch_size= argument passed to zero_state is "
             "batch_size * beam_width.")
         # Re-emit every cell-state tensor under the batch-size checks so
         # that consuming the state also triggers the checks.
         with ops.control_dependencies(
                 self._batch_size_checks(batch_size, error_message)):
             cell_state = nest.map_structure(
                 lambda s: array_ops.identity(s, name="checked_cell_state"),
                 cell_state)
         # One initial alignment per attention mechanism; these are also
         # used below to give each history TensorArray its element shape.
         initial_alignments = [
             attention_mechanism.initial_alignments(batch_size, dtype)
             for attention_mechanism in self._attention_mechanisms
         ]
         return AttentionWrapperState(
             cell_state=cell_state,
             time=array_ops.zeros([], dtype=dtypes.int32),
             attention=_zero_state_tensors(self._attention_layer_size,
                                           batch_size, dtype),
             alignments=self._item_or_tuple(initial_alignments),
             attention_state=self._item_or_tuple(
                 attention_mechanism.initial_state(batch_size, dtype)
                 for attention_mechanism in self._attention_mechanisms),
             # An empty, dynamically sized TensorArray per mechanism when
             # history tracking is on, otherwise an empty tuple placeholder.
             alignment_history=self._item_or_tuple(
                 tensor_array_ops.TensorArray(dtype,
                                              size=0,
                                              dynamic_size=True,
                                              element_shape=alignment.shape)
                 if self._alignment_history else ()
                 for alignment in initial_alignments))
Esempio n. 2
0
    def zero_state(self, batch_size, dtype):
        """Build the initial `AttentionWrapperState` for this wrapper.

        Args:
            batch_size: Batch size requested by the caller; it is asserted
                equal to `self._attention_mechanism.batch_size`.
            dtype: Data type forwarded to the wrapped cell's `zero_state`
                and to `rnn_cell_impl._zero_state_tensors`.

        Returns:
            An `AttentionWrapperState` holding the checked cell state, a
            scalar int32 `time` of zero, the `attention` entry produced by
            `rnn_cell_impl._zero_state_tensors`, and an empty tuple as
            `alignment_history`.
        """
        scope_name = type(self).__name__ + "ZeroState"
        with tf.name_scope(scope_name, values=[batch_size]):
            # A user-provided initial cell state wins over the cell's own.
            initial_state = self._initial_cell_state
            if initial_state is None:
                initial_state = self._cell.zero_state(batch_size, dtype)

            message_prefix = (
                "zero_state of AttentionWrapper %s: " % self._base_name)
            mismatch_message = message_prefix + (
                "Non-matching batch sizes between the memory "
                "(encoder output) and the requested batch size.")

            batch_size_check = tf.assert_equal(
                batch_size,
                self._attention_mechanism.batch_size,
                message=mismatch_message)

            def _checked(tensor):
                # Identity op created under the assertion's control
                # dependency (see the `with` block below).
                return tf.identity(tensor, name="checked_cell_state")

            with tf.control_dependencies([batch_size_check]):
                initial_state = nest.map_structure(_checked, initial_state)

            make_zero_tensors = rnn_cell_impl._zero_state_tensors
            start_time = tf.zeros([], dtype=tf.int32)
            zero_attention = make_zero_tensors(
                self._attention_size, batch_size, dtype)
            return AttentionWrapperState(
                cell_state=initial_state,
                time=start_time,
                attention=zero_attention,
                alignment_history=())
Esempio n. 3
0
 def shape(self):
     """Describe the shape of every component of the wrapper state.

     Returns:
         An `AttentionWrapperState` whose `cell_state` is the wrapped
         cell's `shape`, whose `attention` is `[None, attention_size]`,
         whose `time` is a fully unknown shape, whose `alignments` is
         `[None, None]`, and whose `alignment_history` is an empty tuple.
     """
     cell_shape = self._cell.shape
     attention_shape = tf.TensorShape([None, self._attention_size])
     # TensorShape(None) leaves even the rank of `time` unspecified.
     time_shape = tf.TensorShape(None)
     alignments_shape = tf.TensorShape([None, None])
     return AttentionWrapperState(
         cell_state=cell_shape,
         attention=attention_shape,
         time=time_shape,
         alignments=alignments_shape,
         alignment_history=())
Esempio n. 4
0
 def state_size(self):
     """Report the size of every component of the wrapper state.

     Returns:
         An `AttentionWrapperState` built from the wrapped cell's
         `state_size`, `self._attention_size`, a scalar shape for `time`,
         the attention mechanism's `alignments_size`, and an empty tuple
         for `alignment_history`.
     """
     cell_size = self._cell.state_size
     attention_size = self._attention_size
     scalar = tf.TensorShape([])
     alignments_size = self._attention_mechanism.alignments_size
     return AttentionWrapperState(
         cell_state=cell_size,
         attention=attention_size,
         time=scalar,
         alignments=alignments_size,
         alignment_history=())
Esempio n. 5
0
    def call(self, inputs, state):
        """Run one decoder step with input-feeding attention.

        The wrapped cell is run first; its output then goes through the
        attention pipeline
            h --> a --> c --> h_tilde
        using the naming/notation from Luong et al., 2015.

        Args:
            inputs: `2-D` tensor with shape `[batch_size x input_size]`.
            state: An instance of `AttentionWrapperState` containing the
                tensors from the previous time step.

        Returns:
            A tuple `(attention, next_state)`, where:
            - `attention` is the output of `self._attention_layer` for this
              step; it is returned unconditionally.
            - `next_state` is an `AttentionWrapperState` for this step, with
              `time` advanced by one and an empty `alignment_history`.
        """
        # Input-feeding: the previous h_tilde is appended to the inputs.
        step_inputs = tf.concat([inputs, state.attention], -1)

        # 1. (hidden) Run the wrapped cell to obtain h.
        hidden, new_cell_state = self._cell(step_inputs, state.cell_state)

        # 2. (align) Score h against the memory and normalise with softmax.
        # Per the original notes the result is [B, L_enc], L_enc being the
        # max encoder sequence length in the batch.
        raw_scores = self._attention_mechanism(
            hidden, previous_alignments=state.alignments)
        weights = tf.nn.softmax(raw_scores)

        # 3. (context) Weighted sum over the (possibly projected) encoder
        # outputs: [B, 1, L_enc] x [B, L_enc, state_size] -> [B, 1, state_size],
        # then the singleton axis is dropped.
        weights_3d = tf.expand_dims(weights, 1)
        memory = self._attention_mechanism.values
        context_vector = tf.squeeze(tf.matmul(weights_3d, memory), [1])

        # 4. (h_tilde) Feed [h, c] through the attention layer.
        h_tilde = self._attention_layer(
            tf.concat([hidden, context_vector], -1))

        return h_tilde, AttentionWrapperState(
            cell_state=new_cell_state,
            attention=h_tilde,
            time=state.time + 1,
            alignments=weights,
            alignment_history=())
Esempio n. 6
0
 def state_size(self):
     """The `state_size` property of `AttentionWrapper`.

     Returns:
       An `AttentionWrapperState` tuple containing shapes used by this object.
     """
     mechanisms = self._attention_mechanisms
     alignment_sizes = [m.alignments_size for m in mechanisms]
     mechanism_state_sizes = [m.state_size for m in mechanisms]
     # With history tracking off, each mechanism contributes an empty
     # tuple; otherwise its alignments size (the state entry itself is
     # sometimes a TensorArray).
     if self._alignment_history:
         history_sizes = [m.alignments_size for m in mechanisms]
     else:
         history_sizes = [() for _ in mechanisms]
     return AttentionWrapperState(
         cell_state=self._cell.state_size,
         time=tensor_shape.TensorShape([]),
         attention=self._attention_layer_size,
         alignments=self._item_or_tuple(alignment_sizes),
         attention_state=self._item_or_tuple(mechanism_state_sizes),
         alignment_history=self._item_or_tuple(history_sizes))
    def call(self, inputs, state):
        """Perform a step of attention-wrapped RNN.

        - Step 1: Mix the `inputs` and the previous step's `attention`
          output via `cell_input_fn`.
        - Step 2: Call the wrapped `cell` with this input and its previous
          state, then tie its output to the batch-size checks.
        - Step 3: For every attention mechanism, delegate to
          `_compute_attention` to obtain the attention output, the
          alignments and the next attention state.
        - Step 4: Concatenate the per-mechanism attention outputs and pack
          everything into the next `AttentionWrapperState`.

        Args:
          inputs: (Possibly nested tuple of) Tensor, the input at this time step.
          state: An instance of `AttentionWrapperState` containing
            tensors from the previous time step.

        Returns:
          A tuple `(cell_output, next_state)`, where:
          - `cell_output` is the wrapped cell's (checked) output.
          - `next_state` is an instance of `AttentionWrapperState`
             containing the state calculated at this time step.

          NOTE(review): this variant always returns the cell output and
          never consults an `output_attention` setting -- confirm that this
          is intended.

        Raises:
          TypeError: If `state` is not an instance of `AttentionWrapperState`.
        """
        if not isinstance(state, AttentionWrapperState):
            received = type(state)
            raise TypeError(
                "Expected state to be instance of AttentionWrapperState. "
                "Received type %s instead." % received)

        # Feed the cell with the inputs mixed with the previous attention.
        mixed_inputs = self._cell_input_fn(inputs, state.attention)
        cell_output, next_cell_state = self._cell(
            mixed_inputs, state.cell_state)

        # Use the static batch dimension when it is known (and non-zero),
        # otherwise fall back to the dynamic one.
        static_batch = cell_output.shape[0].value
        if static_batch:
            query_batch_size = static_batch
        else:
            query_batch_size = tf.shape(cell_output)[0]

        message_prefix = "When applying AttentionWrapper %s: " % self.name
        error_message = message_prefix + (
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output).  Are you using "
            "the BeamSearchDecoder?  You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        batch_checks = self._batch_size_checks(query_batch_size, error_message)
        with tf.control_dependencies(batch_checks):
            cell_output = tf.identity(cell_output, name="checked_cell_output")

        # Normalise the single-mechanism case to one-element lists so the
        # loop below can index uniformly.
        prior_states = state.attention_state
        prior_histories = state.alignment_history
        if not self._is_multi:
            prior_states = [prior_states]
            prior_histories = [prior_histories]

        layers = self._attention_layers
        keep_history = self._alignment_history

        step_alignments = []
        step_attentions = []
        step_states = []
        step_histories = []
        for idx, mechanism in enumerate(self._attention_mechanisms):
            prior_state = prior_states[idx]
            layer = layers[idx] if layers else None
            attn_out, align_out, state_out = _compute_attention(
                mechanism, cell_output, prior_state, layer)

            # Record this step's alignments only when history is enabled.
            if keep_history:
                history = prior_histories[idx].write(state.time, align_out)
            else:
                history = ()

            step_states.append(state_out)
            step_alignments.append(align_out)
            step_attentions.append(attn_out)
            step_histories.append(history)

        combined_attention = tf.concat(step_attentions, 1)
        next_state = AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=combined_attention,
            attention_state=self._item_or_tuple(step_states),
            alignments=self._item_or_tuple(step_alignments),
            alignment_history=self._item_or_tuple(step_histories))

        return cell_output, next_state
    def call(self, inputs, state):
        """Perform a step of attention-wrapped RNN, feeding only attention.

        The wrapped cell receives the previous step's `attention` as its
        sole input; `inputs` is accepted but not used. The cell output is
        then passed, per attention mechanism, to `_compute_attention`, and
        the results are packed into the next `AttentionWrapperState`.

        Args:
          inputs: The input at this time step (ignored by this variant).
          state: An instance of `AttentionWrapperState` containing
            tensors from the previous time step.

        Returns:
          A tuple `(attention_or_cell_output, next_state)`: the first item
          is the concatenated attention when `self._output_attention` is
          truthy and the checked cell output otherwise; `next_state` is the
          `AttentionWrapperState` calculated at this time step.

        Raises:
          TypeError: If `state` is not an instance of `AttentionWrapperState`.
        """
        if not isinstance(state, AttentionWrapperState):
            received = type(state)
            raise TypeError(
                "Expected state to be instance of AttentionWrapperState. "
                "Received type %s instead." % received)

        # The usual `self._cell_input_fn(inputs, state.attention)` mix is
        # replaced here: the cell is fed the previous attention alone. Per
        # the original author's note, this single line is the reason the
        # method was rewritten.
        attention_only_inputs = state.attention
        cell_output, next_cell_state = self._cell(
            attention_only_inputs, state.cell_state)

        # Use the static batch dimension when it is known (and non-zero),
        # otherwise fall back to the dynamic one.
        static_batch = cell_output.shape[0].value
        if static_batch:
            query_batch_size = static_batch
        else:
            query_batch_size = tf.shape(cell_output)[0]

        message_prefix = "When applying AttentionWrapper %s: " % self.name
        error_message = message_prefix + (
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output).  Are you using "
            "the BeamSearchDecoder?  You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        batch_checks = self._batch_size_checks(query_batch_size, error_message)
        with tf.control_dependencies(batch_checks):
            cell_output = tf.identity(cell_output, name="checked_cell_output")

        # Normalise the single-mechanism case to one-element lists so the
        # loop below can index uniformly.
        prior_states = state.attention_state
        prior_histories = state.alignment_history
        if not self._is_multi:
            prior_states = [prior_states]
            prior_histories = [prior_histories]

        layers = self._attention_layers
        keep_history = self._alignment_history

        step_alignments = []
        step_attentions = []
        step_states = []
        step_histories = []
        for idx, mechanism in enumerate(self._attention_mechanisms):
            prior_state = prior_states[idx]
            layer = layers[idx] if layers else None
            attn_out, align_out, state_out = _compute_attention(
                mechanism, cell_output, prior_state, layer)

            # Record this step's alignments only when history is enabled.
            if keep_history:
                history = prior_histories[idx].write(state.time, align_out)
            else:
                history = ()

            step_states.append(state_out)
            step_alignments.append(align_out)
            step_attentions.append(attn_out)
            step_histories.append(history)

        combined_attention = tf.concat(step_attentions, 1)
        next_state = AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=combined_attention,
            attention_state=self._item_or_tuple(step_states),
            alignments=self._item_or_tuple(step_alignments),
            alignment_history=self._item_or_tuple(step_histories))

        step_output = (
            combined_attention if self._output_attention else cell_output)
        return step_output, next_state
Esempio n. 9
0
    def call(self, inputs, state):
        """Perform a step of attention-wrapped RNN.

        - Step 1: Mix the `inputs` and previous step's `attention` output via
          `cell_input_fn`.
        - Step 2: Call the wrapped `cell` with this input and its previous state.
        - Step 3: For every attention mechanism, delegate to
          `_compute_attention` to obtain the attention output, the
          alignments and the next attention state.
        - Step 4: Concatenate the attention outputs and build the next state.
        - Step 5: Either record `attention` / `cell_output` into
          `self._storage` (when `self._build_storage` is truthy), or combine
          the stored values with the current step through `self.M` and
          `self.f_gate` (see the inline comments).

        Args:
          inputs: (Possibly nested tuple of) Tensor, the input at this time step.
          state: An instance of `AttentionWrapperState` containing
            tensors from the previous time step.

        Returns:
          A tuple `(attention_or_cell_output, next_state)`, where:

          - `attention_or_cell_output` depending on `output_attention`.
            With `self._fusion_type == 'deep'` (and storage not being
            built) the cell output is the gated mix computed in step 5.
          - `next_state` is an instance of `AttentionWrapperState`
             containing the state calculated at this time step.

        Raises:
          TypeError: If `state` is not an instance of `AttentionWrapperState`.
        """
        if not isinstance(state, AttentionWrapperState):
            raise TypeError(
                "Expected state to be instance of AttentionWrapperState. "
                "Received type %s instead." % type(state))

        # Step 1: Calculate the true inputs to the cell based on the
        # previous attention value.
        cell_inputs = self._cell_input_fn(inputs, state.attention)
        cell_state = state.cell_state
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        # Prefer the static batch dimension; fall back to the dynamic one
        # when it is unknown (or zero).
        cell_batch_size = (cell_output.shape[0].value
                           or array_ops.shape(cell_output)[0])
        error_message = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output).  Are you using "
            "the BeamSearchDecoder?  You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with ops.control_dependencies(
                self._batch_size_checks(cell_batch_size, error_message)):
            cell_output = array_ops.identity(cell_output,
                                             name="checked_cell_output")

        # Normalise the single-mechanism case to one-element lists so the
        # loop below can index uniformly.
        if self._is_multi:
            previous_attention_state = state.attention_state
            previous_alignment_history = state.alignment_history
        else:
            previous_attention_state = [state.attention_state]
            previous_alignment_history = [state.alignment_history]

        all_alignments = []
        all_attentions = []
        all_attention_states = []
        maybe_all_histories = []
        for i, attention_mechanism in enumerate(self._attention_mechanisms):
            attention, alignments, next_attention_state = _compute_attention(
                attention_mechanism, cell_output, previous_attention_state[i],
                self._attention_layers[i] if self._attention_layers else None)
            # Record this step's alignments only when history is enabled.
            alignment_history = previous_alignment_history[i].write(
                state.time, alignments) if self._alignment_history else ()

            all_attention_states.append(next_attention_state)
            all_alignments.append(alignments)
            all_attentions.append(attention)
            maybe_all_histories.append(alignment_history)

        attention = array_ops.concat(all_attentions, 1)
        next_state = AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            attention_state=self._item_or_tuple(all_attention_states),
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(maybe_all_histories))

        # The debug dump of `self._storage` / `attention` to 'tmp_file.txt'
        # that used to run here on every call has been removed.

        if self._build_storage:
            # Record this step's attention and cell output at an offset of
            # `self._start_index`.
            # NOTE(review): the return values of `.write(...)` are dropped;
            # if `self._storage` holds `tf.TensorArray`s the updated arrays
            # are lost -- confirm the storage type.
            self._storage[0].write(self._start_index + state.time, attention)
            self._storage[1].write(self._start_index + state.time, cell_output)
        else:
            # Project the current attention through M and compare it with
            # every stored attention entry.
            # NOTE(review): `T` is defined elsewhere (presumably a
            # transpose helper) and the layout of `self._storage` is not
            # visible here -- confirm the axis order.
            m_attn = tf.tensordot(attention, self.M, axes=[1, 0])
            q = T(
                tf.reduce_sum(tf.transpose(self._storage[0], [1, 0, 2]) *
                              m_attn,
                              axis=2))
            # Combine the stored cell outputs, weighted by q.
            hid_first_storage = tf.transpose(self._storage[1], [2, 0, 1])
            z_tilda = tf.reduce_sum(tf.transpose(q * hid_first_storage,
                                                 [1, 2, 0]),
                                    axis=1)

            # Gate value computed from the current step and the stored mix.
            concat = tf.concat([attention, cell_output, z_tilda], axis=1)
            dzeta = tf.squeeze(self.f_gate(concat))

            if self._fusion_type == 'deep':
                # Blend the cell output with z_tilda using the gate.
                cell_output = T(T(cell_output) *
                                (1. - dzeta)) + T(dzeta * T(z_tilda))
            else:
                x = T(dzeta * T(q))

                # NOTE(review): as above, the results of `.write(...)` are
                # dropped -- confirm `self._p_copy` is mutated in place.
                self._p_copy[0].write(state.time, x)
                self._p_copy[1].write(state.time, 1. - dzeta)

        if self._output_attention:
            return attention, next_state
        else:
            return cell_output, next_state