Example #1
    def __init__(self,
                 num_units,
                 memory,
                 memory_sequence_length=None,
                 normalize=False,
                 probability_fn=None,
                 score_mask_value=None,
                 dtype=None,
                 name_or_scope='BahdanauAttention'):
        """ Construct the Attention mechanism.
            :param num_units: The depth of the query mechanism.
            :param memory: The memory to query; usually the output of an RNN encoder. This tensor should be
                           shaped `[batch_size, max_time, ...]`.
            :param memory_sequence_length: (optional) Sequence lengths for the batch entries in memory. If provided,
                                           the memory tensor rows are masked with zeros for values past the respective
                                           sequence lengths.
            :param normalize: Python boolean. Whether to normalize the energy term.
            :param probability_fn: (optional) A `callable`. Converts the score to probabilities. The default is
                                   `tf.nn.softmax`. Other options include `tf.contrib.seq2seq.hardmax` and
                                   `tf.contrib.sparsemax.sparsemax`.
                                   Its signature should be: `probabilities = probability_fn(score)`.
            :param score_mask_value: (optional) The mask value for score before passing into `probability_fn`.
                                     The default is -inf. Only used if `memory_sequence_length` is not None.
            :param dtype: The data type for the query and memory layers of the attention mechanism.
            :param name_or_scope: String or VariableScope to use when creating ops.
        """
        # pylint: disable=too-many-arguments
        if probability_fn is None:
            probability_fn = nn_ops.softmax
        if dtype is None:
            dtype = dtypes.float32
        wrapped_probability_fn = lambda score, _: probability_fn(score)

        self._num_units = num_units
        self._normalize = normalize
        self._name_or_scope = name_or_scope

        with variable_scope.variable_scope(name_or_scope,
                                           default_name='BahdanauAttention'):
            super(BahdanauAttention,
                  self).__init__(query_layer=core.Dense(num_units,
                                                        name='query_layer',
                                                        use_bias=False,
                                                        dtype=dtype),
                                 memory_layer=core.Dense(num_units,
                                                         name='memory_layer',
                                                         use_bias=False,
                                                         dtype=dtype),
                                 memory=memory,
                                 probability_fn=wrapped_probability_fn,
                                 memory_sequence_length=memory_sequence_length,
                                 score_mask_value=score_mask_value)
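
A minimal usage sketch for this constructor (not part of the original example), assuming TensorFlow 1.x and the `BahdanauAttention` class defined above; the encoder tensors and sizes below are illustrative placeholders.

    import tensorflow as tf  # TF 1.x

    # Illustrative encoder output: batch of 4, 10 time steps, depth 128.
    encoder_outputs = tf.random.normal([4, 10, 128])
    source_lengths = tf.constant([10, 8, 6, 10])

    attention_mechanism = BahdanauAttention(
        num_units=256,
        memory=encoder_outputs,
        memory_sequence_length=source_lengths,
        normalize=True)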
Example #2
    def _compute_attention(self, alignments, memory):
        """Computes the attention and alignments for a given attention_mechanism."""
        # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
        expanded_alignments = array_ops.expand_dims(alignments, 1)

        # Context is the inner product of alignments and values along the
        # memory time dimension.
        # alignments shape is  [batch_size, 1, memory_time]
        # memory is [batch_size, memory_time, memory_size]
        # the batched matmul is over memory_time, so the output shape is [batch_size, 1, memory_size].
        # we then squeeze out the singleton dim.
        context = math_ops.matmul(expanded_alignments, memory)
        context = array_ops.squeeze(context, [1])
        attn_layer = lambda x: x
        if self._attention_layer_size != self._memory_size:
            attn_layer = core.Dense(self._attention_layer_size,
                                    name='attn_layer',
                                    use_bias=False,
                                    dtype=context.dtype)
        attention = attn_layer(context)
        return attention, alignments
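
The context computation above is a batched matmul followed by a squeeze over the singleton dimension; the standalone sketch below (not part of the original example) reproduces the same shape arithmetic with public TensorFlow ops and illustrative shapes.

    import tensorflow as tf

    batch_size, memory_time, memory_size = 2, 5, 8
    alignments = tf.nn.softmax(tf.random.normal([batch_size, memory_time]))  # [batch, time]
    memory = tf.random.normal([batch_size, memory_time, memory_size])        # [batch, time, memory_size]

    # [batch, 1, time] @ [batch, time, memory_size] -> [batch, 1, memory_size]
    context = tf.matmul(tf.expand_dims(alignments, 1), memory)
    context = tf.squeeze(context, [1])  # [batch, memory_size]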
Example #3
    def _compute_attention(self, query, memory):
        """ Computes the attention and alignments for the Bahdanau attention mechanism .
            :param query: The query (inputs) to use to compute attention. Size [b, input_size]
            :param memory: The memory (previous outputs) used to compute attention [b, time_step, memory_size]
            :return: The attention. Size [b, attn_size]
        """
        assert len(memory.shape) == 3, 'Memory needs to be [batch, time, memory_size]'
        memory_time = array_ops.shape(memory)[1]
        memory_size = memory.shape[2]
        num_units = self._num_units
        assert self._memory_size == memory_size, 'Expected mem size of %s - Got %s' % (
            self._memory_size, memory_size)

        # Query, memory, and attention layers
        query_layer = core.Dense(num_units,
                                 name='query_layer',
                                 use_bias=False,
                                 dtype=self._dtype)
        memory_layer = lambda x: x
        if memory_size != self._num_units:
            memory_layer = core.Dense(num_units,
                                      name='memory_layer',
                                      use_bias=False,
                                      dtype=self._dtype)
        attn_layer = lambda x: x
        if self._attention_layer_size is not None and memory_size != self._attention_layer_size:
            attn_layer = core.Dense(self._attention_layer_size,
                                    name='attn_layer',
                                    use_bias=False,
                                    dtype=self._dtype)

        # Masking memory
        sequence_length = gen_math_ops.minimum(memory_time,
                                               self._sequence_length)
        sequence_mask = array_ops.sequence_mask(sequence_length,
                                                maxlen=memory_time,
                                                dtype=dtypes.float32)[..., None]
        values = memory * sequence_mask
        keys = memory_layer(values)

        # Computing scores
        processed_query = query_layer(query)
        scores = _bahdanau_score(processed_query, keys, self._normalize)

        # Getting alignments
        masked_scores = _maybe_mask_score(scores, sequence_length,
                                          self._score_mask_value)
        alignments = self._wrapped_probability_fn(masked_scores,
                                                  None)  # [batch, time]

        # Getting attention
        expanded_alignments = array_ops.expand_dims(alignments,
                                                    1)  # [batch, 1, time]
        context = math_ops.matmul(expanded_alignments,
                                  memory)  # [batch, 1, memory_size]
        context = array_ops.squeeze(context, [1])  # [batch, memory_size]
        attention = attn_layer(context)  # [batch, attn_size]

        # Returning attention
        return attention
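
The `_bahdanau_score` helper called above is not shown in this example; a minimal sketch of the unnormalized additive score, assuming `processed_query` is `[batch, num_units]` and `keys` is `[batch, time, num_units]`, might look like the following (the function and variable names are hypothetical).

    import tensorflow as tf  # TF 1.x

    def bahdanau_score_sketch(processed_query, keys):
        """Unnormalized additive (Bahdanau) score: v^T tanh(keys + query)."""
        num_units = keys.shape[-1].value
        attention_v = tf.get_variable('attention_v', [num_units], dtype=keys.dtype)
        processed_query = tf.expand_dims(processed_query, 1)  # [batch, 1, num_units]
        return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query), axis=2)  # [batch, time]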
Example #4
    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name_or_scope='AttentionWrapper',
                 attention_layer=None):
        """ Construct the `AttentionWrapper`.

            :param cell: An instance of `RNNCell`.
            :param attention_mechanism: A list of `AttentionMechanism` instances or a single instance.
            :param attention_layer_size: A list of Python integers or a single Python integer,
                                         the depth of the attention (output) layer(s).
            :param alignment_history: Python boolean, whether to store alignment history from all time steps in the
                                      final output state.
            :param cell_input_fn: (optional) A `callable`. The default is: concat([inputs, attention], axis=-1)
            :param output_attention: Python bool. If `True` (default), the output at each time step is the attn value.
            :param initial_cell_state: The initial state value to use for the cell when the user calls `zero_state()`.
            :param name_or_scope: String or VariableScope to use when creating ops.
            :param attention_layer: A list of `tf.layers.Layer` instances or a single `tf.layers.Layer` instance taking
                                    the context and cell output as inputs to generate attention at each time step.
                                    If None (default), use the context as attention at each time step.

            **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`,
                     then you must ensure that:

            - The encoder output has been tiled to `beam_width` via `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
            - The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to
                `true_batch_size * beam_width`.
            - The initial state created with `zero_state` above contains a `cell_state` value containing properly
                tiled final state from the encoder.
        """
        # pylint: disable=too-many-arguments
        self._name_or_scope = name_or_scope
        with variable_scope.variable_scope(name_or_scope, 'AttentionWrapper'):
            super(AttentionWrapper, self).__init__()
            rnn_cell_impl.assert_like_rnncell("cell", cell)

            # Attention mechanism
            if isinstance(attention_mechanism, (list, tuple)):
                self._is_multi = True
                attention_mechanisms = attention_mechanism
                for attn_mechanism in attention_mechanisms:
                    if not isinstance(attn_mechanism, AttentionMechanism):
                        raise TypeError(
                            'attention_mechanism must contain only instances of AttentionMechanism, saw '
                            'type: %s' % type(attn_mechanism).__name__)
            else:
                self._is_multi = False
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        'attention_mechanism must be an AttentionMechanism or list of multiple '
                        'AttentionMechanism instances, saw type: %s' %
                        type(attention_mechanism).__name__)
                attention_mechanisms = (attention_mechanism, )

            # Cell input function
            if cell_input_fn is None:
                cell_input_fn = lambda inputs, attention: array_ops.concat(
                    [inputs, attention], -1)
            else:
                if not callable(cell_input_fn):
                    raise TypeError(
                        'cell_input_fn must be callable, saw type: %s' %
                        type(cell_input_fn).__name__)

            # Attention layer size
            if attention_layer_size is not None and attention_layer is not None:
                raise ValueError(
                    'Only one of attention_layer_size and attention_layer should be set'
                )

            if attention_layer_size is not None:
                attention_layer_sizes = tuple(
                    attention_layer_size
                    if isinstance(attention_layer_size, (list, tuple))
                    else (attention_layer_size, ))
                if len(attention_layer_sizes) != len(attention_mechanisms):
                    raise ValueError(
                        'If provided, attention_layer_size must contain exactly one integer per '
                        'attention_mechanism, saw: %d vs %d' %
                        (len(attention_layer_sizes),
                         len(attention_mechanisms)))
                self._attention_layers = tuple(
                    core.Dense(attention_layer_size,
                               name='attention_layer',
                               use_bias=False,
                               dtype=attention_mechanisms[i].dtype)
                    for i, attention_layer_size in enumerate(attention_layer_sizes))
                self._attention_layer_size = sum(attention_layer_sizes)

            elif attention_layer is not None:
                self._attention_layers = tuple(attention_layer if isinstance(
                    attention_layer, (list, tuple)) else (attention_layer, ))
                if len(self._attention_layers) != len(attention_mechanisms):
                    raise ValueError(
                        'If provided, attention_layer must contain exactly one layer per '
                        'attention_mechanism, saw: %d vs %d' %
                        (len(self._attention_layers),
                         len(attention_mechanisms)))
                self._attention_layer_size = \
                    sum(tensor_shape.dimension_value(
                        layer.compute_output_shape([None,
                                                    cell.output_size
                                                    + tensor_shape.dimension_value(mechanism.values.shape[-1])])[-1])
                        for layer, mechanism in zip(self._attention_layers, attention_mechanisms))
            else:
                self._attention_layers = None
                self._attention_layer_size = sum(
                    tensor_shape.dimension_value(
                        attention_mechanism.values.shape[-1])
                    for attention_mechanism in attention_mechanisms)

            self._cell = cell
            self._attention_mechanisms = attention_mechanisms
            self._cell_input_fn = cell_input_fn
            self._output_attention = output_attention
            self._alignment_history = alignment_history

            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (tensor_shape.dimension_value(
                    final_state_tensor.shape[0])
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    'When constructing AttentionWrapper %s: ' % self._base_name
                    +
                    'Non-matching batch sizes between the memory (encoder output) and initial_cell_state. '
                    'Are you using the BeamSearchDecoder? You may need to tile your initial state via the '
                    'tf.contrib.seq2seq.tile_batch function with argument multiple=beam_width.'
                )

                with ops.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = \
                        nest.map_structure(lambda state: array_ops.identity(state, name='check_initial_cell_state'),
                                           initial_cell_state)
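
A minimal sketch of how this wrapper might be constructed (not part of the original example), assuming TensorFlow 1.x, the `AttentionWrapper` class above, and an `attention_mechanism` built as in Example #1; the cell type and sizes are illustrative.

    import tensorflow as tf  # TF 1.x

    cell = tf.nn.rnn_cell.LSTMCell(256)
    attn_cell = AttentionWrapper(cell,
                                 attention_mechanism,  # e.g. the BahdanauAttention from Example #1
                                 attention_layer_size=128,
                                 alignment_history=True)
    # Assumes the wrapper exposes zero_state() like the standard tf.contrib.seq2seq.AttentionWrapper.
    initial_state = attn_cell.zero_state(batch_size=4, dtype=tf.float32)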
Example #5
    def _step(self, inputs, past_attns, time, feeder_cell, feeder_state):
        """ Performs the block operation on n-layers
            :param inputs: The tensor inputs (embedding of each word) - [batch, seq_len, emb_size]
            :param past_attns: The past attentions - [batch, nb_layers, 2, nb_heads, past_length, emb_size // nb_heads]
            :param time: A tensor representing the current time step
            :param feeder_cell: None or A feeder cell that returns a RNN cell output to use for conditioning
            :param feeder_state: None or the initial state of the feeder cell
            :return: A tuple consisting of:
                        1) The cell outputs - [batch, seq_len, emb_size]
                        2) The present attention - [batch, nb_layers, 2, nb_heads, seq_len, emb_size // nb_heads]
                        3) The new state of the feeder cell
        """
        with variable_scope.variable_scope(self._scope, default_name='step'):
            past_length = array_ops.shape(past_attns)[-2]  # How many past attention steps we have
            seq_len = array_ops.shape(inputs)[-2]  # How many steps we are computing for the current time
            emb_size = inputs.shape[-1].value  # The size of the embedding
            assert emb_size == self._emb_size, 'Expected an embedding size of %d' % self._emb_size

            # 1) Computing the word embedding of each token
            assert inputs.shape.ndims == 3, 'Expected [batch, seq_len, emb_size]'  # [bz, seq, emb]
            out_h = inputs

            # 2) Computing the position embedding of each token
            # If we know the context was padded, the effective past length is the context length + nb of time steps
            if self._past_seq_lengths is not None:
                past_length = gen_math_ops.minimum(
                    past_length,
                    self._past_seq_lengths + time)[:, None]  # [bz, 1]
            else:
                past_length = gen_array_ops.fill([self._batch_size, 1],
                                                 value=past_length)  # [bz, 1]
            step_ix = math_ops.range(seq_len)[None, :]  # [1, seq_len]
            token_positions = gen_math_ops.add(past_length,
                                               step_ix)  # [batch, seq_len]
            token_positions = gen_math_ops.minimum(
                self._position_emb_size - 1,
                token_positions)  # [batch, seq_len]
            h_pos = self._position_embedding_fn(
                token_positions)  # [bz, seq, emb]
            out_h = out_h + h_pos

            # 3) If we have a feeder cell, we also need to condition 'h' on it.
            next_feeder_state = feeder_state
            if feeder_cell is not None:
                assert feeder_state is not None, 'A feeder state is required if a feeder cell is provided.'
                assert inputs.shape[1].value == 1, 'The seq dimension must be 1 to use a feeder_cell'
                feeder_outputs, next_feeder_state = feeder_cell(
                    array_ops.squeeze(inputs, axis=1), feeder_state)
                h_feed = feeder_outputs  # [bz, feeder_sz]
                if feeder_outputs.shape[-1].value != emb_size:
                    h_feed = core.Dense(emb_size,
                                        activation=None,
                                        name='h_feed')(h_feed)  # [bz, emb]
                h_feed = gen_array_ops.tile(h_feed[:, None, :],
                                            [1, seq_len, 1])  # [bz, seq, emb]
                out_h = out_h + h_feed

            # Transformer
            presents = []
            pasts = array_ops.unstack(past_attns, axis=1)  # list of [batch, 2, heads, past_len, head_sz]
            assert len(pasts) == self._nb_layers, \
                'Expected the past attention to have %d layers.' % self._nb_layers

            for layer_ix, past_attn in enumerate(pasts):
                out_h, present = self._block(out_h, past_attn,
                                             'layer.%d' % layer_ix)
                presents += [present]
            presents = array_ops.stack(presents, axis=1)

            # Normalizing and returning
            cell_outputs = self._norm(out_h, 'norm_h')  # [batch, seq, emb]
            return cell_outputs, presents, next_feeder_state
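
Step 2 above derives each token's absolute position by adding the effective past length to the within-step index and clipping to the position-embedding table size; the standalone sketch below (not part of the original example) reproduces that arithmetic with public ops and illustrative values.

    import tensorflow as tf

    batch_size, seq_len, position_emb_size = 2, 3, 1024
    past_length = tf.fill([batch_size, 1], 7)  # [batch, 1] effective past length
    step_ix = tf.range(seq_len)[None, :]       # [1, seq_len]
    token_positions = past_length + step_ix    # [batch, seq_len] via broadcasting
    token_positions = tf.minimum(position_emb_size - 1, token_positions)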