Example #1
    def _conv1d(inputs, output_dim, name=None):
        """ Performs 1d convolution
            :param inputs: The tensor inputs - [axis_0, ..., axis_n-1, input_dim]
            :param output_dim: The output size of the 1d convolution
            :param name: Name of the scope - To share weights between calls
            :return: The tensors after the convolution - [axis_0, ..., axis_n-1, output_dim]
        """
        with variable_scope.variable_scope(name):
            input_dims = [
                array_ops.shape(inputs)[axis]
                if dim.value is None else dim.value
                for axis, dim in enumerate(inputs.shape)
            ]
            input_prefix_dims, input_last_dim = input_dims[:-1], input_dims[-1]
            weight = variable_scope.get_variable(
                'w',
                shape=[input_last_dim, output_dim],
                initializer=init_ops.random_normal_initializer(stddev=0.02))
            beta = variable_scope.get_variable(
                'b',
                shape=[output_dim],
                initializer=init_ops.constant_initializer(0.))

            inputs = gen_array_ops.reshape(
                inputs, [-1, input_last_dim])  # [B, input_last_dim]
            outputs = math_ops.matmul(inputs, weight) + beta  # [B, output_dim]
            return gen_array_ops.reshape(outputs, input_prefix_dims +
                                         [output_dim])  # [..., output_dim]
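
The helper above is effectively a kernel-size-1 ("pointwise") convolution: a single weight matrix applied to the last axis and shared across all leading axes. A minimal NumPy sketch of the same computation, with made-up shapes, is shown below.

    import numpy as np

    # Hedged sketch: reshape to [-1, input_dim], project, reshape back - the same
    # flow as _conv1d above, just without variable scopes.
    inputs = np.random.randn(4, 7, 16)              # [batch, seq_len, input_dim]
    weight = np.random.randn(16, 32) * 0.02         # [input_dim, output_dim]
    beta = np.zeros(32)                             # [output_dim]

    flat = inputs.reshape(-1, 16)                   # [batch * seq_len, input_dim]
    outputs = (flat @ weight + beta).reshape(4, 7, 32)
    assert outputs.shape == (4, 7, 32)              # [batch, seq_len, output_dim]
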
Example #2
    def call(self, inputs, state):  # pylint: disable=arguments-differ
        """ Perform a step of attention-wrapped RNN
            :param inputs: (Possibly nested tuple of) Tensor, the input at this time step.
            :param state: An instance of `SelfAttentionWrapperState` containing tensors from the previous time step.
            :return: A tuple `(attention_or_cell_output, next_state)`, where:
                    - `attention_or_cell_output` is either the attention or the cell output, depending on `output_attention`.
                    - `next_state` is an instance of `SelfAttentionWrapperState` containing the state calculated at
                       this time step.
        """
        if not isinstance(state, SelfAttentionWrapperState):
            raise TypeError(
                'Expected state to be an instance of SelfAttentionWrapperState. Received type %s instead.'
                % type(state))

        # Getting batch size
        batch_size = array_ops.shape(inputs)[0]
        assert len(inputs.shape) == 2, 'Expected inputs to be of rank 2'

        def get_next_memory_and_attn():
            """ Gets the next memory and attention """
            next_memory = array_ops.concat(
                [
                    state.memory,  # [b, t, mem_size]
                    array_ops.expand_dims(self._input_fn(inputs), axis=1)
                ],
                axis=1)
            next_attention = self._compute_attention(inputs, next_memory)
            with ops.control_dependencies([next_memory, next_attention]):
                return array_ops.identity(next_memory), array_ops.identity(
                    next_attention)

        def get_zero_memory_and_attn():
            """ Time = 0, we don't concatenate to memory and attention is all 0. """
            next_memory = state.memory
            next_attention = array_ops.zeros(
                [batch_size, self._attention_layer_size], dtype=inputs.dtype)
            with ops.control_dependencies([next_memory, next_attention]):
                return array_ops.identity(next_memory), array_ops.identity(
                    next_attention)

        # Computing memory and attention
        memory, attention = control_flow_ops.cond(
            gen_math_ops.equal(state.time, 0),
            true_fn=get_zero_memory_and_attn,
            false_fn=get_next_memory_and_attn)

        # Calculate the true inputs to the cell based on the previous attention value.
        cell_inputs = self._cell_input_fn(inputs, attention)
        cell_state = state.cell_state
        cell_output, cell_state = self._cell(cell_inputs, cell_state)

        # Extracting computed context
        next_state = SelfAttentionWrapperState(cell_state=cell_state,
                                               time=state.time + 1,
                                               memory=memory)

        # Returning cell output or attention
        if self._output_attention:
            return attention, next_state
        return cell_output, next_state
Example #3
 def greedy():
     """ Selecting greedy """
     argmax_id = math_ops.cast(math_ops.argmax(cell_outputs, axis=-1), dtypes.int32)
     nb_candidate = array_ops.shape(candidate)[1]
     candidate_ids = \
         math_ops.reduce_sum(array_ops.one_hot(argmax_id, nb_candidate, dtype=dtypes.int32) * candidate,
                             axis=-1)
     with ops.control_dependencies([candidate_ids]):
         return array_ops.identity(candidate_ids)
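
The one_hot / reduce_sum combination above turns an argmax over candidate positions into the corresponding candidate token ids. A small NumPy illustration with hypothetical values:

    import numpy as np

    # Hedged sketch: cell_outputs scores each candidate position, candidate holds
    # the actual token ids; the one-hot product selects the winning id per batch row.
    cell_outputs = np.array([[0.1, 2.0, 0.3],
                             [1.5, 0.2, 0.1]])      # [batch, nb_candidate] logits
    candidate = np.array([[11, 42, 7],
                          [99, 13, 5]])             # [batch, nb_candidate] token ids

    argmax_id = cell_outputs.argmax(axis=-1)                        # [1, 0]
    one_hot = np.eye(candidate.shape[1], dtype=int)[argmax_id]      # [batch, nb_candidate]
    candidate_ids = (one_hot * candidate).sum(axis=-1)              # [42, 99]
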
Example #4
    def _merge_heads(self, inputs):
        """ Merges the attn heads of the tensor into a single dimension
            :param inputs: The tensor to merge - [batch, nb_heads, seq_len, head_size]
            :return: A tensor in the format - [batch, seq_len, nb_heads * head_size]
        """
        assert inputs.shape.ndims == 4, 'Expected inputs to be [batch, nb_heads, seq_len, head_size]'
        assert inputs.shape[
            1].value == self._nb_heads, 'Expected the 2nd dimension to be the number of heads'

        # Transposing to [batch, seq_len, nb_heads, head_size]
        inputs = array_ops.transpose(inputs, [0, 2, 1, 3])

        # Merging last 2 dims
        batch_size = array_ops.shape(inputs)[0]
        seq_len = array_ops.shape(inputs)[1]
        head_size = inputs.shape[-1].value
        return gen_array_ops.reshape(
            inputs, [batch_size, seq_len, self._nb_heads * head_size])
Example #5
    def _split_in_heads(self, inputs):
        """ Splits the tensor into heads of size attn_dim / heads
            :param inputs: The tensor to split - [batch, seq_len, attn_dim]
            :return: A tensor in the format - [batch, nb_heads, seq_len, attn_dim // nb_heads]
        """
        assert inputs.shape.ndims == 3, 'Expected inputs to be [batch, seq_len, attn_dim]'
        attn_dim = inputs.shape[-1].value
        assert attn_dim % self._nb_heads == 0, 'The attn_dim must be evenly divisible by the nb of heads'

        # Reshaping to [batch, seq_len, nb_heads, head_size]
        batch_size = array_ops.shape(inputs)[0]
        seq_len = array_ops.shape(inputs)[1]
        head_size = attn_dim // self._nb_heads
        inputs = gen_array_ops.reshape(
            inputs, [batch_size, seq_len, self._nb_heads, head_size])

        # Transposing to [batch, nb_heads, seq_len, head_size]
        return array_ops.transpose(inputs, [0, 2, 1, 3])
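
Splitting into heads is a reshape followed by a transpose, and _merge_heads (example #4) is its exact inverse, so a round trip recovers the original tensor. A hedged NumPy sketch with arbitrary sizes:

    import numpy as np

    batch, seq_len, attn_dim, nb_heads = 2, 5, 12, 4
    head_size = attn_dim // nb_heads
    x = np.random.randn(batch, seq_len, attn_dim)

    # Split: [batch, seq_len, attn_dim] -> [batch, nb_heads, seq_len, head_size]
    split = x.reshape(batch, seq_len, nb_heads, head_size).transpose(0, 2, 1, 3)

    # Merge: transpose back and collapse the last two axes
    merged = split.transpose(0, 2, 1, 3).reshape(batch, seq_len, nb_heads * head_size)
    assert np.array_equal(merged, x)
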
Example #6
    def _mask_attn_weights(self, attn_weights):
        """ Masks the attention weights
            :param attn_weights: The attention weights - [batch, nb_head, seq_len, seq_len + past_length]
            :return: A tensor of 0s and 1s with the same shape and dtype as attn_weights
        """
        seq_len = array_ops.shape(attn_weights)[-2]
        total_len = array_ops.shape(attn_weights)[-1]

        # 1) Creating the attention mask matrix (with the lower triangle set to 1. on the right)
        # e.g. if seq_len == 3, and total_len == 10
        # the attention mask would be:       - [seq_len, total_len]
        # [[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        #  [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        #  [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]
        num_lower = math_ops.cast(-1, dtypes.int32)
        num_upper = total_len - seq_len
        attn_mask = gen_array_ops.matrix_band_part(
            array_ops.ones([seq_len, total_len]), num_lower, num_upper)

        # No past_attentions/context - We just add two leading dimensions to attn_mask and can return it
        if self._past_seq_lengths is None:
            return attn_mask[None, None, :, :]

        # If we have a context with varying sequence length, we also need to mask the items after the end of sequence
        # e.g.
        # [[1., 1., 1., 0., 0., 0., 0., 1., 1., 1.],            # => length of 3 (padded to 7) + seq_len of 3
        #  [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],            # => length of 7 (padded to 7) + seq_len of 3
        #  [1., 1., 1., 1., 1., 0., 0., 1., 1., 1.]]            # => length of 5 (padded to 7) + seq_len of 3
        #
        # The resulting attention mask would be the product of the two.
        # [[1., 1., 1., 0., 0., 0., 0., 1., 0., 0.],
        #  [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        #  [1., 1., 1., 1., 1., 0., 0., 1., 1., 1.]]
        seq_mask = array_ops.sequence_mask(
            self._past_seq_lengths, dtype=dtypes.float32)  # [b, max_len]
        seq_mask = pad_axis(seq_mask, axis=-1,
                            min_size=total_len)  # [b, total_len]

        # Returning the multiplication of the two masks
        return gen_math_ops.mul(attn_mask[None, None, :, :],
                                seq_mask[:, None,
                                         None, :])  # [b, nb_heads, seq, total]
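
The matrix_band_part call above keeps element [i, j] whenever j - i <= total_len - seq_len (the lower band is unrestricted), which produces exactly the causal mask shown in the comment. A hedged NumPy equivalent for seq_len=3, total_len=10:

    import numpy as np

    seq_len, total_len = 3, 10
    i = np.arange(seq_len)[:, None]
    j = np.arange(total_len)[None, :]
    attn_mask = ((j - i) <= (total_len - seq_len)).astype(np.float32)
    # attn_mask[0] -> [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.]
    # attn_mask[2] -> all ones
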
Example #7
 def sample():
     """ Sampling """
     logits = cell_outputs if self._softmax_temperature is None else cell_outputs / self._softmax_temperature
     sample_id_sampler = categorical.Categorical(logits=logits)
     sample_ids = sample_id_sampler.sample(seed=self._seed)
     nb_candidate = array_ops.shape(candidate)[1]
     reduce_op = math_ops.reduce_sum(array_ops.one_hot(sample_ids,
                                                       nb_candidate,
                                                       dtype=dtypes.int32) * candidate, axis=-1)
     with ops.control_dependencies([reduce_op]):
         return array_ops.identity(reduce_op)
Example #8
        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """ Internal while_loop body. """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = gen_math_ops.logical_or(
                    decoder_finished, finished)
            next_sequence_lengths = array_ops.where(
                gen_math_ops.logical_not(finished),
                gen_array_ops.fill(array_ops.shape(sequence_lengths),
                                   time + 1), sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays, multiple dynamic dims, and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                elif None in new.shape.as_list()[1:]:
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)
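
The body above freezes outputs and lengths once a sequence has finished: finished rows emit zeros (when impute_finished is set) and their recorded lengths stop growing. A hedged NumPy illustration of those two where/fill updates, with hypothetical values:

    import numpy as np

    time = 4
    finished = np.array([False, True, False])
    next_outputs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    sequence_lengths = np.array([0, 3, 0])

    emit = np.where(finished[:, None], 0.0, next_outputs)              # zero out finished rows
    next_sequence_lengths = np.where(~finished, time + 1, sequence_lengths)
    # emit -> [[1., 2.], [0., 0.], [5., 6.]]
    # next_sequence_lengths -> [5, 3, 5]
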
Example #9
    def next_inputs(self, time, inputs, beam_search_output, beam_search_state):
        """ Computes the inputs at the next time step given the beam outputs
            :param time: The current time step (scalar)
            :param inputs: A (structure of) input tensors.
            :param beam_search_output: The output of the beam search step
            :param beam_search_state: The state after the beam search step
            :return: `(beam_search_output, next_inputs)`
            :type beam_search_output: beam_search_decoder.BeamSearchDecoderOutput
            :type beam_search_state: beam_search_decoder.BeamSearchDecoderState
        """
        next_time = time + 1
        all_finished = math_ops.reduce_all(next_time >= self._sequence_length)

        # Sampling
        next_word_ids = beam_search_output.predicted_ids
        candidates = inputs.candidates
        nb_candidates = array_ops.shape(candidates)[1]
        sample_ids = math_ops.reduce_sum(array_ops.one_hot(next_word_ids, nb_candidates, dtype=dtypes.int32)
                                         * array_ops.expand_dims(candidates, axis=1), axis=-1)

        def get_next_inputs():
            """ Retrieves the inputs for the next time step """
            inputs_next_step = sample_ids
            inputs_emb_next_step = self._input_layer(self._order_embedding_fn(inputs_next_step))
            candidate_next_step = self._candidate_tas.read(next_time)
            candidate_emb_next_step = self._candidate_embedding_fn(candidate_next_step)

            # Prevents this branch from executing eagerly
            with ops.control_dependencies([inputs_emb_next_step, candidate_next_step, candidate_emb_next_step]):
                return CandidateInputs(inputs=array_ops.identity(inputs_emb_next_step),
                                       candidates=array_ops.identity(candidate_next_step),
                                       candidates_emb=array_ops.identity(candidate_emb_next_step))

        # Getting next inputs
        next_inputs = control_flow_ops.cond(all_finished,
                                            true_fn=lambda: self._zero_inputs,
                                            false_fn=get_next_inputs)

        # Rewriting beam search output with the correct sample ids
        beam_search_output = beam_search_decoder.BeamSearchDecoderOutput(scores=beam_search_output.scores,
                                                                         predicted_ids=sample_ids,
                                                                         parent_ids=beam_search_output.parent_ids)

        # Returning
        return beam_search_output, next_inputs
Example #10
    def _compute_attention(self, query, memory):
        """ Computes the attention and alignments for the Bahdanau attention mechanism .
            :param query: The query (inputs) to use to compute attention. Size [b, input_size]
            :param memory: The memory (previous outputs) used to compute attention [b, time_step, memory_size]
            :return: The attention. Size [b, attn_size]
        """
        assert len(
            memory.shape) == 3, 'Memory needs to be [batch, time, memory_size]'
        memory_time = array_ops.shape(memory)[1]
        memory_size = memory.shape[2]
        num_units = self._num_units
        assert self._memory_size == memory_size, 'Expected mem size of %s - Got %s' % (
            self._memory_size, memory_size)

        # Query, memory, and attention layers
        query_layer = core.Dense(num_units,
                                 name='query_layer',
                                 use_bias=False,
                                 dtype=self._dtype)
        memory_layer = lambda x: x
        if memory_size != self._num_units:
            memory_layer = core.Dense(num_units,
                                      name='memory_layer',
                                      use_bias=False,
                                      dtype=self._dtype)
        attn_layer = lambda x: x
        if self._attention_layer_size is not None and memory_size != self._attention_layer_size:
            attn_layer = core.Dense(self._attention_layer_size,
                                    name='attn_layer',
                                    use_bias=False,
                                    dtype=self._dtype)

        # Masking memory
        sequence_length = gen_math_ops.minimum(memory_time,
                                               self._sequence_length)
        sequence_mask = array_ops.sequence_mask(sequence_length,
                                                maxlen=memory_time,
                                                dtype=dtypes.float32)[...,
                                                                      None]
        values = memory * sequence_mask
        keys = memory_layer(values)

        # Computing scores
        processed_query = query_layer(query)
        scores = _bahdanau_score(processed_query, keys, self._normalize)

        # Getting alignments
        masked_scores = _maybe_mask_score(scores, sequence_length,
                                          self._score_mask_value)
        alignments = self._wrapped_probability_fn(masked_scores,
                                                  None)  # [batch, time]

        # Getting attention
        expanded_alignments = array_ops.expand_dims(alignments,
                                                    1)  # [batch, 1, time]
        context = math_ops.matmul(expanded_alignments,
                                  memory)  # [batch, 1, memory_size]
        context = array_ops.squeeze(context, [1])  # [batch, memory_size]
        attention = attn_layer(context)  # [batch, attn_size]

        # Returning attention
        return attention
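
The alignment and context computation above reduces to a softmax over the masked scores followed by an alignment-weighted sum over the memory rows. A hedged NumPy sketch of that expand_dims / matmul / squeeze sequence, with made-up sizes and the masking omitted:

    import numpy as np

    batch, time, memory_size = 2, 4, 8
    scores = np.random.randn(batch, time)
    memory = np.random.randn(batch, time, memory_size)

    # Softmax over time, then alignment-weighted sum of memory rows
    alignments = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # [batch, time]
    context = np.einsum('bt,btm->bm', alignments, memory)                     # [batch, memory_size]
    assert context.shape == (batch, memory_size)
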
Example #11
    def __init__(self,
                 cell,
                 memory,
                 alignments,
                 sequence_length,
                 probability_fn=None,
                 score_mask_value=None,
                 attention_layer_size=None,
                 cell_input_fn=None,
                 output_attention=False,
                 name=None):
        """ Constructs an AttentionWrapper with static alignments (attention weights)

            :param cell: An instance of `RNNCell`.
            :param memory: The memory to query [batch_size, memory_time, memory_size]
            :param alignments: A tensor of probabilities of shape [batch_size, time_steps, memory_time]
            :param sequence_length: Sequence lengths for the batch entries in memory. Size (b,)
            :param probability_fn: A `callable` that converts the score to probabilities. The default is `tf.nn.softmax`.
            :param score_mask_value:  The mask value for score before passing into `probability_fn`. Default is -inf.
            :param attention_layer_size: The size of the attention layer. Uses the context if None.
            :param cell_input_fn: (optional) A `callable` to aggregate attention.
                                  Default: `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
            :param output_attention: If true, outputs the attention, if False outputs the cell output.
            :param name: Name to use when creating ops.
        """
        # pylint: disable=too-many-arguments
        # Initializing RNN Cell
        super(StaticAttentionWrapper, self).__init__(name=name)
        rnn_cell_impl.assert_like_rnncell('cell', cell)

        # Setting values
        self._cell = cell
        self._memory = memory
        self._attention_layer_size = attention_layer_size
        self._output_attention = output_attention
        self._memory_time = alignments.get_shape()[-1].value
        self._memory_size = memory.get_shape()[-1].value
        self._sequence_length = sequence_length

        # Validating attention layer size
        if self._attention_layer_size is None:
            self._attention_layer_size = self._memory_size

        # Validating cell_input_fn
        if cell_input_fn is None:
            cell_input_fn = lambda inputs, attention: array_ops.concat(
                [inputs, attention], -1)
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    'cell_input_fn must be callable, saw type: %s' %
                    type(cell_input_fn).__name__)
        self._cell_input_fn = cell_input_fn

        # Probability Function
        if probability_fn is None:
            probability_fn = nn_ops.softmax
        if score_mask_value is None:
            score_mask_value = dtypes.as_dtype(
                self._memory.dtype).as_numpy_dtype(-np.inf)
        self._probability_fn = lambda score, _: probability_fn(
            _maybe_mask_score(score, sequence_length, score_mask_value), _)

        # Storing alignments as TA
        # Padding with 1 additional zero, to prevent error on read(0)
        alignments = array_ops.pad(alignments, [(0, 0), (0, 1), (0, 0)])
        alignments = nest.map_structure(
            _transpose_batch_time,
            alignments)  # (max_time + 1, b, memory_time)
        self._alignments_ta = nest.map_structure(
            _unstack_ta, alignments)  # [time_step + 1, batch, memory_time]
        self._initial_alignment = self._alignments_ta.read(0)
        self._initial_attention = self._compute_attention(
            self._initial_alignment, self._memory)[0]

        # Storing zero inputs
        batch_size = array_ops.shape(memory)[0]
        self._zero_cell_output = array_ops.zeros(
            [batch_size, cell.output_size])
        self._zero_attention = array_ops.zeros(
            [batch_size, self._attention_layer_size])
        self._zero_state = self.zero_state(batch_size, dtypes.float32)
        self._zero_alignment = array_ops.zeros_like(self._initial_alignment)
Example #12
    def call(self, inputs, state):
        """ Performs a step of attention-wrapped RNN.

            1) Mix the `inputs` and previous step's `attention` output via `cell_input_fn`.
            2) Call the wrapped `cell` with this input and its previous state.
            3) Score the cell's output with `attention_mechanism`.
            4) Calculate the alignments by passing the score through the `normalizer`.
            5) Calculate the context vector as the inner product between the alignments and the attention_mechanism's
               values (memory).
            6) Calculate the attention output by concatenating the cell output and context through the attention
               layer (a linear layer with `attention_layer_size` outputs).

            :param inputs: (Possibly nested tuple of) Tensor, the input at this time step.
            :param state: An instance of `AttentionWrapperState` containing tensors from the previous time step.
            :return: A tuple `(attention_or_cell_output, next_state)`, where:
            - `attention_or_cell_output` is either the attention or the cell output, depending on `output_attention`.
            - `next_state` is an instance of `AttentionWrapperState` containing the state calculated at this time step.
        """
        # pylint: disable=arguments-differ
        if not isinstance(state, AttentionWrapperState):
            raise TypeError(
                'Expected state to be an instance of AttentionWrapperState. Received %s instead.'
                % type(state))

        # Step 1: Calculate the true inputs to the cell based on the previous attention value.
        cell_inputs = self._cell_input_fn(inputs, state.attention)
        cell_state = state.cell_state
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        cell_batch_size = tensor_shape.dimension_value(
            cell_output.shape[0]) or array_ops.shape(cell_output)[0]
        error_message = (
            'When applying AttentionWrapper %s: ' % self.name +
            'Non-matching batch sizes between '
            'the memory (encoder output) and the query (decoder output). Are you using the '
            'BeamSearchDecoder? You may need to tile your memory input via the tf.contrib.seq2seq.'
            'tile_batch function with argument multiple=beam_width.')

        with variable_scope.variable_scope(self._name_or_scope,
                                           'AttentionWrapper',
                                           [inputs, state]):

            with ops.control_dependencies(
                    self._batch_size_checks(cell_batch_size, error_message)):
                cell_output = array_ops.identity(cell_output,
                                                 name='checked_cell_output')

            if self._is_multi:
                previous_attention_state = state.attention_state
                previous_alignment_history = state.alignment_history
            else:
                previous_attention_state = [state.attention_state]
                previous_alignment_history = [state.alignment_history]

            # Computing attention
            all_alignments = []
            all_attentions = []
            all_attention_states = []
            maybe_all_histories = []
            for i, attention_mechanism in enumerate(
                    self._attention_mechanisms):
                attention, alignments, next_attention_state = _compute_attention(
                    attention_mechanism, cell_output,
                    previous_attention_state[i], self._attention_layers[i]
                    if self._attention_layers else None)
                alignment_history = previous_alignment_history[i].write(
                    state.time, alignments) if self._alignment_history else ()

                all_attention_states.append(next_attention_state)
                all_alignments.append(alignments)
                all_attentions.append(attention)
                maybe_all_histories.append(alignment_history)

            # Building next state
            attention = array_ops.concat(all_attentions, 1)
            next_state = AttentionWrapperState(
                time=state.time + 1,
                cell_state=next_cell_state,
                attention=attention,
                attention_state=self._item_or_tuple(all_attention_states),
                alignments=self._item_or_tuple(all_alignments),
                alignment_history=self._item_or_tuple(maybe_all_histories))

            # Returning
            if self._output_attention:
                return attention, next_state
            return cell_output, next_state
Example #13
    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name_or_scope='AttentionWrapper',
                 attention_layer=None):
        """ Construct the `AttentionWrapper`.

            :param cell: An instance of `RNNCell`.
            :param attention_mechanism: A list of `AttentionMechanism` instances or a single instance.
            :param attention_layer_size: A list of Python integers or a single Python integer,
                                         the depth of the attention (output) layer(s).
            :param alignment_history: Python boolean, whether to store alignment history from all time steps in the
                                      final output state
            :param cell_input_fn: (optional) A `callable`. The default is: concat([inputs, attention], axis=-1)
            :param output_attention: Python bool. If `True` (default), the output at each time step is the attn value.
            :param initial_cell_state: The initial state value to use for the cell when the user calls `zero_state()`.
            :param name_or_scope: String or VariableScope to use when creating ops.
            :param attention_layer: A list of `tf.layers.Layer` instances or a single `tf.layers.Layer` instance taking
                                    the context and cell output as inputs to generate attention at each time step.
                                    If None (default), use the context as attention at each time step.

            **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`,
                     then you must ensure that:

            - The encoder output has been tiled to `beam_width` via `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
            - The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to
                `true_batch_size * beam_width`.
            - The initial state created with `zero_state` above contains a `cell_state` value containing properly
                tiled final state from the encoder.
        """
        # pylint: disable=too-many-arguments
        self._name_or_scope = name_or_scope
        with variable_scope.variable_scope(name_or_scope, 'AttentionWrapper'):
            super(AttentionWrapper, self).__init__()
            rnn_cell_impl.assert_like_rnncell("cell", cell)

            # Attention mechanism
            if isinstance(attention_mechanism, (list, tuple)):
                self._is_multi = True
                attention_mechanisms = attention_mechanism
                for attn_mechanism in attention_mechanisms:
                    if not isinstance(attn_mechanism, AttentionMechanism):
                        raise TypeError(
                            'attention_mechanism must contain only instances of AttentionMechanism, saw '
                            'type: %s' % type(attn_mechanism).__name__)
            else:
                self._is_multi = False
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        'attention_mechanism must be an AttentionMechanism or list of multiple '
                        'AttentionMechanism instances, saw type: %s' %
                        type(attention_mechanism).__name__)
                attention_mechanisms = (attention_mechanism, )

            # Cell input function
            if cell_input_fn is None:
                cell_input_fn = lambda inputs, attention: array_ops.concat(
                    [inputs, attention], -1)
            else:
                if not callable(cell_input_fn):
                    raise TypeError(
                        'cell_input_fn must be callable, saw type: %s' %
                        type(cell_input_fn).__name__)

            # Attention layer size
            if attention_layer_size is not None and attention_layer is not None:
                raise ValueError(
                    'Only one of attention_layer_size and attention_layer should be set'
                )

            if attention_layer_size is not None:
                attention_layer_sizes = tuple(
                    attention_layer_size if isinstance(attention_layer_size, (
                        list, tuple)) else (attention_layer_size, ))
                if len(attention_layer_sizes) != len(attention_mechanisms):
                    raise ValueError(
                        'If provided, attention_layer_size must contain exactly one integer per '
                        'attention_mechanism, saw: %d vs %d' %
                        (len(attention_layer_sizes),
                         len(attention_mechanisms)))
                self._attention_layers = tuple(
                    core.Dense(attention_layer_size,
                               name='attention_layer',
                               use_bias=False,
                               dtype=attention_mechanisms[i].dtype) for i,
                    attention_layer_size in enumerate(attention_layer_sizes))
                self._attention_layer_size = sum(attention_layer_sizes)

            elif attention_layer is not None:
                self._attention_layers = tuple(attention_layer if isinstance(
                    attention_layer, (list, tuple)) else (attention_layer, ))
                if len(self._attention_layers) != len(attention_mechanisms):
                    raise ValueError(
                        'If provided, attention_layer must contain exactly one layer per '
                        'attention_mechanism, saw: %d vs %d' %
                        (len(self._attention_layers),
                         len(attention_mechanisms)))
                self._attention_layer_size = \
                    sum(tensor_shape.dimension_value(
                        layer.compute_output_shape([None,
                                                    cell.output_size
                                                    + tensor_shape.dimension_value(mechanism.values.shape[-1])])[-1])
                        for layer, mechanism in zip(self._attention_layers, attention_mechanisms))
            else:
                self._attention_layers = None
                self._attention_layer_size = sum(
                    tensor_shape.dimension_value(
                        attention_mechanism.values.shape[-1])
                    for attention_mechanism in attention_mechanisms)

            self._cell = cell
            self._attention_mechanisms = attention_mechanisms
            self._cell_input_fn = cell_input_fn
            self._output_attention = output_attention
            self._alignment_history = alignment_history

            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (tensor_shape.dimension_value(
                    final_state_tensor.shape[0])
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    'When constructing AttentionWrapper %s: ' % self._base_name
                    +
                    'Non-matching batch sizes between the memory (encoder output) and initial_cell_state. '
                    'Are you using the BeamSearchDecoder? You may need to tile your initial state via the '
                    'tf.contrib.seq2seq.tile_batch function with argument multiple=beam_width.'
                )

                with ops.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = \
                        nest.map_structure(lambda state: array_ops.identity(state, name='check_initial_cell_state'),
                                           initial_cell_state)
Example #14
 def training():
     """ Selecting training / teacher forcing """
     fill_op = gen_array_ops.fill([array_ops.shape(outputs)[0]], -1)
     with ops.control_dependencies([fill_op]):
         return array_ops.identity(fill_op)
Example #15
    def _step(self, inputs, past_attns, time, feeder_cell, feeder_state):
        """ Performs the block operation on n-layers
            :param inputs: The tensor inputs (embedding of each word) - [batch, seq_len, emb_size]
            :param past_attns: The past attentions - [batch, nb_layers, 2, nb_heads, past_length, emb_size // nb_heads]
            :param time: A tensor representing the current time step
            :param feeder_cell: None or A feeder cell that returns a RNN cell output to use for conditioning
            :param feeder_state: None or the initial state of the feeder cell
            :return: A tuple consisting of:
                        1) The cell outputs - [batch, seq_len, emb_size]
                        2) The present attention - [batch, nb_layers, 2, nb_heads, seq_len, emb_size // nb_heads]
                        3) The new state of the feeder cell
        """
        with variable_scope.variable_scope(self._scope, default_name='step'):
            past_length = array_ops.shape(past_attns)[
                -2]  # How many past attention steps we have
            seq_len = array_ops.shape(inputs)[
                -2]  # How many steps are we computing for the current time
            emb_size = inputs.shape[-1].value  # The size of the embedding
            assert emb_size == self._emb_size, 'Expected an embedding size of %d' % self._emb_size

            # 1) Computing the word embedding of each token
            assert inputs.shape.ndims == 3, 'Expected [batch, seq_len, emb_size]'  # [bz, seq, emb]
            out_h = inputs

            # 2) Computing the position embedding of each token
            # If we know the context was padded, the effective past length is the context length + nb of time steps
            if self._past_seq_lengths is not None:
                past_length = gen_math_ops.minimum(
                    past_length,
                    self._past_seq_lengths + time)[:, None]  # [bz, 1]
            else:
                past_length = gen_array_ops.fill([self._batch_size, 1],
                                                 value=past_length)  # [bz, 1]
            step_ix = math_ops.range(seq_len)[None, :]  # [1, seq_len]
            token_positions = gen_math_ops.add(past_length,
                                               step_ix)  # [batch, seq_len]
            token_positions = gen_math_ops.minimum(
                self._position_emb_size - 1,
                token_positions)  # [batch, seq_len]
            h_pos = self._position_embedding_fn(
                token_positions)  # [bz, seq, emb]
            out_h = out_h + h_pos

            # 3) If we have a feeder cell, we also need to condition 'h' on it.
            next_feeder_state = feeder_state
            if feeder_cell is not None:
                assert feeder_state is not None, 'A feeder state is required if a feeder cell is provided.'
                assert inputs.shape[
                    1].value == 1, 'The seq dimension must be 1 to use a feeder_cell'
                feeder_outputs, next_feeder_state = feeder_cell(
                    array_ops.squeeze(inputs, axis=1), feeder_state)
                h_feed = feeder_outputs  # [bz, feeder_sz]
                if feeder_outputs.shape[-1].value != emb_size:
                    h_feed = core.Dense(emb_size,
                                        activation=None,
                                        name='h_feed')(h_feed)  # [bz, emb]
                h_feed = gen_array_ops.tile(h_feed[:, None, :],
                                            [1, seq_len, 1])  # [bz, seq, emb]
                out_h = out_h + h_feed

            # Transformer
            presents = []
            pasts = array_ops.unstack(
                past_attns,
                axis=1)  # list of [batch, 2, heads, past_len, head_sz]
            assert len(
                pasts
            ) == self._nb_layers, 'Expected the past attention to have %d layers.' % self._nb_layers

            for layer_ix, past_attn in enumerate(pasts):
                out_h, present = self._block(out_h, past_attn,
                                             'layer.%d' % layer_ix)
                presents += [present]
            presents = array_ops.stack(presents, axis=1)

            # Normalizing and returning
            cell_outputs = self._norm(out_h, 'norm_h')  # [batch, seq, emb]
            return cell_outputs, presents, next_feeder_state
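
The position indexing in step 2) adds each new token's offset within the current call to the effective past length, then clips to the size of the position table. A hedged NumPy sketch with hypothetical lengths:

    import numpy as np

    position_emb_size = 1024
    past_length = np.array([[3], [7]])       # [batch, 1] effective past lengths
    seq_len = 2
    step_ix = np.arange(seq_len)[None, :]    # [1, seq_len]
    token_positions = np.minimum(position_emb_size - 1, past_length + step_ix)
    # token_positions -> [[3, 4], [7, 8]]
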