Exemple #1
0
 def zero_state(self, batch_size, dtype):
     """ Build the initial (all-zero) state of this self-attention wrapper.
         :param batch_size: `0D` integer tensor: the batch size.
         :param dtype: The internal state data type, forwarded to the wrapped cell's `zero_state`.
         :return: A `SelfAttentionWrapperState` whose `cell_state` is the wrapped cell's zero state, whose `time`
                  is an int32 scalar equal to 0, and whose `memory` is an empty tensor of shape
                  (batch_size, 0, memory_size) with dtype `self._dtype`.
     """
     scope_name = type(self).__name__ + 'ZeroState'
     with ops.name_scope(scope_name, values=[batch_size]):
         zero_cell_state = self._cell.zero_state(batch_size, dtype)

         # The memory starts with no time steps. The time axis is written as `batch_size * 0` rather than a
         # literal 0 so that it is a dynamic dimension.
         # Note: the memory uses `self._dtype`, not the `dtype` argument.
         empty_time_dim = batch_size * 0
         zero_memory = array_ops.zeros([batch_size, empty_time_dim, self._memory_size], dtype=self._dtype)

         zero_time = array_ops.zeros([], dtype=dtypes.int32)
         return SelfAttentionWrapperState(cell_state=zero_cell_state, time=zero_time, memory=zero_memory)
Exemple #2
0
 def get_zero_memory_and_attn():
     """ Handle time step 0: nothing is concatenated to the memory yet and the attention is all zeros.
         Reads `state`, `batch_size`, `inputs` and `self` from the enclosing scope.
         :return: A tuple (memory, attention). `memory` is `state.memory` unchanged and `attention` is a zero
                  tensor of shape [batch_size, self._attention_layer_size] with the dtype of `inputs`.
                  Both are wrapped in identity ops that have control dependencies on both tensors.
     """
     memory = state.memory
     attention_shape = [batch_size, self._attention_layer_size]
     zero_attention = array_ops.zeros(attention_shape, dtype=inputs.dtype)

     # The returned identities only run once both the memory and the zero attention are available
     with ops.control_dependencies([memory, zero_attention]):
         memory_out = array_ops.identity(memory)
         attention_out = array_ops.identity(zero_attention)
     return memory_out, attention_out
Exemple #3
0
    def __init__(self, cell, order_embedding, candidate_embedding, candidates, sequence_length, initial_state,
                 beam_width, input_layer=None, output_layer=None, time_major=False):
        """ Initialize the CustomBeamHelper.

            Stores the cell, the embedding functions and the per-time-step candidates, then builds the initial beam
            state (start tokens, split initial cell state, finished flags) and the zero `CandidateInputs`.

            :param cell: An `RNNCell` instance.
            :param order_embedding: The order embedding vector  - Size: (batch, ord_emb_size)
                                    NOTE(review): it is looked up by token id (`self._start_tokens`) below, so its
                                    first dimension is presumably the vocabulary rather than the batch - confirm.
                                    Its `shape[-1]` is read directly, so it presumably must be a tensor-like object
                                    rather than a callable.
            :param candidate_embedding: The candidate embedding vector - Size: (batch, cand_emb_size)
            :param candidates: The candidates at each time step -- Size: (batch, nb_cand, max_candidates)
                               Converted to time-major internally unless `time_major` is True.
            :param sequence_length: The length of each sequence (batch,). Its size also defines the batch size.
            :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
                                  Each element is passed through `_maybe_split_batch_beams` with `cell.state_size`.
            :param beam_width: Python integer, the number of beams.
            :param input_layer: Optional. A layer to apply on the inputs (identity if None). When provided, its
                                `compute_output_shape` is used to infer the input size.
            :param output_layer: Optional. An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer
                                 to apply to the RNN output prior to storing the result or sampling.
            :param time_major: If true indicates that the first dimension is time, otherwise it is batch size.
            :raises AssertionError: If `beam_width` is not a Python int (checked with `assert`, so skipped under -O).
            :raises ValueError: If `sequence_length` is not a rank-1 tensor.
        """
        # pylint: disable=super-init-not-called,too-many-arguments
        rnn_cell_impl.assert_like_rnncell('cell', cell)                                                                 # pylint: disable=protected-access
        assert isinstance(beam_width, int), 'beam_width should be a Python integer'

        self._sequence_length = ops.convert_to_tensor(sequence_length, name='sequence_length')
        if self._sequence_length.get_shape().ndims != 1:
            raise ValueError("Expected vector for sequence_length. Shape: %s" % self._sequence_length.get_shape())

        # Candidates are made time-major so that `_unstack_ta` can presumably unstack them along the time axis
        candidates = ops.convert_to_tensor(candidates, name='candidates')
        candidates = nest.map_structure(_transpose_batch_time, candidates) if not time_major else candidates

        self._cell = cell
        self._order_embedding_fn = _get_embedding_fn(order_embedding)
        self._candidate_embedding_fn = _get_embedding_fn(candidate_embedding)
        self._candidate_tas = nest.map_structure(_unstack_ta, candidates)
        self._input_layer = input_layer if input_layer is not None else lambda x: x
        self._output_layer = output_layer

        # Size of the (optionally projected) order input
        self._input_size = order_embedding.shape[-1]
        if input_layer is not None:
            self._input_size = self._input_layer.compute_output_shape([None, self._input_size])[-1]

        # Flat (batch * beam_width) vector of GO_ID start tokens
        self._batch_size = array_ops.size(sequence_length)
        self._start_tokens = gen_array_ops.fill([self._batch_size * beam_width], GO_ID)
        # NOTE(review): -1 is presumably never produced as a token id, so beams are presumably never finished by an
        # end token (only by `sequence_length`) - confirm against the helper's step logic.
        self._end_token = -1
        self._beam_width = beam_width
        self._initial_cell_state = nest.map_structure(self._maybe_split_batch_beams,
                                                      initial_state,
                                                      self._cell.state_size)
        # (batch, beam_width) bool: beam 0 is False (not finished) and every other beam is True (finished),
        # i.e. only the first beam of each batch entry starts unfinished.
        self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size], dtype=dtypes.int32),
                                           depth=self._beam_width,
                                           on_value=False,
                                           off_value=True,
                                           dtype=dtypes.bool)

        # Compute input shape
        # Zero inputs have the same shapes as the first-step inputs, but are all zeros:
        #   inputs:         embedded (and input_layer-projected) start tokens, presumably reshaped by
        #                   `_split_batch_beams` to (batch, beam, input_size)
        #   candidates:     the candidates of the first time step (index 0 of the time-major tensor)
        #   candidates_emb: the embeddings of those first-step candidates
        self._zero_inputs = \
            CandidateInputs(inputs=
                            array_ops.zeros_like(self._split_batch_beams(
                                self._input_layer(self._order_embedding_fn(self._start_tokens)),
                                self._input_size)),
                            candidates=array_ops.zeros_like(candidates[0, :]),
                            candidates_emb=array_ops.zeros_like(self._candidate_embedding_fn(candidates[0, :])))
Exemple #4
0
 def zero_state(self, batch_size, dtype):
     """ Build the initial state of this `IdentityCell`.
         :param batch_size: `0D` integer tensor: the batch size. Only passed as `values` to the name scope.
         :param dtype: The internal state data type. Not used: the state is always an int32 scalar.
         :return: An int32 scalar tensor equal to 0.
     """
     scope_name = '{}ZeroState'.format(type(self).__name__)
     with ops.name_scope(scope_name, values=[batch_size]):
         initial_state = array_ops.zeros([], dtype=dtypes.int32)
     return initial_state
Exemple #5
0
    def zero_state(self, batch_size, dtype):
        """ Return the initial state tuple for this transformer cell.
            :param batch_size: `0D` integer tensor: the batch size.
            :param dtype: The internal state data type (used for the feeder cell's state).
            :return: A `TransformerCellState` tuple with:
                        - `past_attentions`: without a context, `self._past_attns` if it is set, otherwise an empty
                          float32 tensor of shape (batch_size, nb_layers, 2, nb_heads, 0, emb_size // nb_heads).
                          With a context, the present attentions from a single block step on the embedded context.
                        - `feeder_state`: a zero scalar of type `dtype` if there is no feeder cell, otherwise the
                          provided feeder initial state or, if none, the feeder cell's zero state.
                        - `time`: an int32 scalar equal to 0.
        """
        with ops.name_scope(type(self).__name__ + 'ZeroState',
                            values=[batch_size]):
            if self._feeder_cell is None:
                feeder_init_state = array_ops.zeros([], dtype=dtype)
            elif self._feeder_init_state is not None:
                feeder_init_state = self._feeder_init_state
            else:
                feeder_init_state = self._feeder_cell.zero_state(
                    batch_size, dtype)

            # Empty past attentions
            # Kept in a local variable (not cached on `self`): caching would make a later call with a different
            # `batch_size` (e.g. batch_size * beam_width) return a stale tensor of the wrong shape, created in an
            # earlier name scope / graph.
            past_attns = self._past_attns
            if past_attns is None:
                head_size = self._emb_size // self._nb_heads
                # Time axis written as `0 * batch_size` rather than 0, presumably to keep it a dynamic dimension
                past_attns_shape = [
                    batch_size, self._nb_layers, 2, self._nb_heads,
                    0 * batch_size, head_size
                ]
                past_attns = array_ops.zeros(past_attns_shape,
                                             dtype=dtypes.float32)

            # No Context - Returning a zero past attention
            if self._context is None:
                return TransformerCellState(past_attentions=past_attns,
                                            feeder_state=feeder_init_state,
                                            time=array_ops.zeros(
                                                [], dtype=dtypes.int32))

            # Context provided - Computing attention by running a single block step
            _, present_attns, _ = self._step(
                inputs=self._context_word_embedding_fn(self._context),
                past_attns=past_attns,
                time=0,
                feeder_cell=None,
                feeder_state=None)
            return TransformerCellState(past_attentions=present_attns,
                                        feeder_state=feeder_init_state,
                                        time=array_ops.zeros(
                                            [], dtype=dtypes.int32))
Exemple #6
0
 def zero_state(self, batch_size, dtype):
     """ Build the initial state of this `ArrayConcatWrapper`.
         :param batch_size: `0D` integer tensor: the batch size.
         :param dtype: The internal state data type, forwarded to the wrapped cell's `zero_state`.
         :return: An `ArrayConcatWrapperState` holding the wrapped cell's zero state and an int32 `time` of 0.
     """
     scope_name = '%sZeroState' % type(self).__name__
     with ops.name_scope(scope_name, values=[batch_size]):
         wrapped_cell_state = self._cell.zero_state(batch_size, dtype)
         initial_time = array_ops.zeros([], dtype=dtypes.int32)
         return ArrayConcatWrapperState(cell_state=wrapped_cell_state, time=initial_time)
Exemple #7
0
    def zero_state(self, batch_size, dtype):
        """ Return an initial (zero) state tuple for this `AttentionWrapper`.
            :param batch_size: `0D` integer tensor: the batch size.
            :param dtype: The internal state data type.
            :return: An `AttentionWrapperState` tuple containing zeroed out tensors and, when alignment history is
                     enabled, one empty dynamically-sized `TensorArray` per attention mechanism.
        """
        with ops.name_scope(type(self).__name__ + 'ZeroState',
                            values=[batch_size]):
            # Use the provided initial cell state if any, otherwise the wrapped cell's zero state
            if self._initial_cell_state is not None:
                cell_state = self._initial_cell_state
            else:
                cell_state = self._cell.zero_state(batch_size, dtype)

            error_message = (
                'When calling zero_state of AttentionWrapper %s: ' %
                self._base_name +
                'Non-matching batch sizes between the memory (encoder output) and the requested batch '
                'size. Are you using the BeamSearchDecoder? If so, make sure your encoder output has been '
                'tiled to beam_width via tf.contrib.seq2seq.tile_batch, and the batch_size= argument '
                'passed to zero_state is batch_size * beam_width.')

            # Every cell state tensor goes through an identity op that has the batch size checks as control
            # dependencies, so the checks run whenever the initial cell state is used.
            with ops.control_dependencies(
                    self._batch_size_checks(batch_size, error_message)):
                cell_state = nest.map_structure(
                    lambda state: array_ops.identity(
                        state, name='checked_cell_state'), cell_state)
            initial_alignments = [
                attention_mechanism.initial_alignments(batch_size, dtype)
                for attention_mechanism in self._attention_mechanisms
            ]
            return AttentionWrapperState(
                cell_state=cell_state,
                time=array_ops.zeros([], dtype=dtypes.int32),
                attention=_zero_state_tensors(self._attention_layer_size,
                                              batch_size, dtype),
                alignments=self._item_or_tuple(initial_alignments),
                attention_state=self._item_or_tuple(
                    attention_mechanism.initial_state(batch_size, dtype)
                    for attention_mechanism in self._attention_mechanisms),
                # One growable TensorArray per mechanism when history is tracked, otherwise an empty tuple
                alignment_history=self._item_or_tuple(
                    tensor_array_ops.TensorArray(dtype,
                                                 size=0,
                                                 dynamic_size=True,
                                                 element_shape=alignment.shape)
                    if self._alignment_history else ()
                    for alignment in initial_alignments))
Exemple #8
0
    def __init__(self,
                 cell,
                 memory,
                 alignments,
                 sequence_length,
                 probability_fn=None,
                 score_mask_value=None,
                 attention_layer_size=None,
                 cell_input_fn=None,
                 output_attention=False,
                 name=None):
        """ Constructs an AttentionWrapper with static alignments (attention weights)

            The alignments are given up front for every time step. They are stored in a TensorArray with one entry
            per time step, plus one trailing all-zero padding entry.

            :param cell: An instance of `RNNCell`.
            :param memory: The memory to query [batch_size, memory_time, memory_size]
                           Its static last dimension (`memory_size`) is read with `.value`.
            :param alignments: A tensor of probabilities of shape [batch_size, time_steps, memory_time]
                               Its static last dimension (`memory_time`) is read with `.value`.
            :param sequence_length: Sequence lengths for the batch entries in memory. Size (b,)
            :param probability_fn: A `callable`.  Converts the score to probabilities.  The default is @{tf.nn.softmax}.
                                   It is wrapped so that the score is first passed through `_maybe_mask_score`.
            :param score_mask_value:  The mask value for score before passing into `probability_fn`. Default is -inf
                                      (converted to the memory's dtype).
            :param attention_layer_size: The size of the attention layer. Defaults to the memory depth (last dimension
                                         of `memory`) if None.
            :param cell_input_fn: (optional) A `callable` to aggregate attention.
                                  Default: `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
            :param output_attention: If true, outputs the attention, if False outputs the cell output.
            :param name: Name to use when creating ops.
            :raises TypeError: If `cell_input_fn` is provided but is not callable.
        """
        # pylint: disable=too-many-arguments
        # Initializing RNN Cell
        super(StaticAttentionWrapper, self).__init__(name=name)
        rnn_cell_impl.assert_like_rnncell('cell', cell)

        # Setting values
        self._cell = cell
        self._memory = memory
        self._attention_layer_size = attention_layer_size
        self._output_attention = output_attention
        self._memory_time = alignments.get_shape()[-1].value
        self._memory_size = memory.get_shape()[-1].value
        self._sequence_length = sequence_length

        # Validating attention layer size
        # Defaults to the memory depth when not provided
        if self._attention_layer_size is None:
            self._attention_layer_size = self._memory_size

        # Validating cell_input_fn
        # Default: concatenate the inputs and the attention along the last axis
        if cell_input_fn is None:
            cell_input_fn = lambda inputs, attention: array_ops.concat(
                [inputs, attention], -1)
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    'cell_input_fn must be callable, saw type: %s' %
                    type(cell_input_fn).__name__)
        self._cell_input_fn = cell_input_fn

        # Probability Function
        # The score is passed through `_maybe_mask_score` (presumably masking positions beyond `sequence_length`
        # with `score_mask_value`) before `probability_fn`; the second lambda argument is forwarded unchanged.
        if probability_fn is None:
            probability_fn = nn_ops.softmax
        if score_mask_value is None:
            score_mask_value = dtypes.as_dtype(
                self._memory.dtype).as_numpy_dtype(-np.inf)
        self._probability_fn = lambda score, _: probability_fn(
            _maybe_mask_score(score, sequence_length, score_mask_value), _)

        # Storing alignments as TA
        # Padding with 1 additional zero, to prevent error on read(0)
        # (one all-zero step is appended at the end of the time axis, so the TensorArray is never empty)
        alignments = array_ops.pad(alignments, [(0, 0), (0, 1), (0, 0)])
        alignments = nest.map_structure(
            _transpose_batch_time,
            alignments)  # (max_time + 1, b, memory_time)
        self._alignments_ta = nest.map_structure(
            _unstack_ta, alignments)  # [time_step + 1, batch, memory_time]
        # Alignment of the first time step and the attention computed from it
        self._initial_alignment = self._alignments_ta.read(0)
        self._initial_attention = self._compute_attention(
            self._initial_alignment, self._memory)[0]

        # Storing zero inputs
        # Sized with the dynamic batch size of `memory`; `self.zero_state` is called after every attribute above
        # has been set.
        batch_size = array_ops.shape(memory)[0]
        self._zero_cell_output = array_ops.zeros(
            [batch_size, cell.output_size])
        self._zero_attention = array_ops.zeros(
            [batch_size, self._attention_layer_size])
        self._zero_state = self.zero_state(batch_size, dtypes.float32)
        self._zero_alignment = array_ops.zeros_like(self._initial_alignment)
Exemple #9
0
    def __init__(self,
                 cell,
                 embedding,
                 mask,
                 sequence_length,
                 initial_state,
                 beam_width,
                 input_layer=None,
                 output_layer=None,
                 time_major=False):
        """ Initialize the CustomBeamHelper.

            Stores the cell, embedding function and sparse mask, then builds the initial beam state (start tokens,
            split initial cell state, finished flags) and the zero `MaskedInputs`.

            :param cell: An `RNNCell` instance.
            :param embedding: The embedding vector. Its `shape[-1]` is read directly, so it presumably must be a
                              tensor-like object rather than a callable.
            :param mask: [SparseTensor] Mask to apply at each time step -- Size: (b, dec_len, vocab_size, vocab_size)
                         It is not transposed here; `time_major` is passed to `_slice_mask` instead.
            :param sequence_length: The length of each input (b,). Its size also defines the batch size.
            :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
                                  Each element is passed through `_maybe_split_batch_beams` with `cell.state_size`.
            :param beam_width: Python integer, the number of beams.
            :param input_layer: Optional. A layer to apply on the inputs (identity if None). When provided, its
                                `compute_output_shape` is used to infer the input size.
            :param output_layer: Optional. An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer
                                 to apply to the RNN output prior to storing the result or sampling.
            :param time_major: If true indicates that the first dimension is time, otherwise it is batch size.
            :raises AssertionError: If `mask` is not a SparseTensor or `beam_width` is not a Python int
                                    (checked with `assert`, so skipped under -O).
            :raises ValueError: If `sequence_length` is not a rank-1 tensor.
        """
        # pylint: disable=super-init-not-called,too-many-arguments
        rnn_cell_impl.assert_like_rnncell('cell', cell)  # pylint: disable=protected-access
        assert isinstance(mask,
                          SparseTensor), 'The mask must be a SparseTensor'
        assert isinstance(beam_width,
                          int), 'beam_width should be a Python integer'

        self._sequence_length = ops.convert_to_tensor(sequence_length,
                                                      name='sequence_length')
        if self._sequence_length.get_shape().ndims != 1:
            raise ValueError("Expected vector for sequence_length. Shape: %s" %
                             self._sequence_length.get_shape())

        self._cell = cell
        self._embedding_fn = _get_embedding_fn(embedding)
        self._mask = mask
        self._time_major = time_major
        # Public attribute, taken from the module-level VOCABULARY_SIZE constant
        self.vocab_size = VOCABULARY_SIZE
        self._input_layer = input_layer if input_layer is not None else lambda x: x
        self._output_layer = output_layer

        # Size of the (optionally projected) input
        self._input_size = embedding.shape[-1]
        if input_layer is not None:
            self._input_size = self._input_layer.compute_output_shape(
                [None, self._input_size])[-1]

        # Flat (batch * beam_width) vector of GO_ID start tokens
        self._batch_size = array_ops.size(sequence_length)
        self._start_tokens = gen_array_ops.fill(
            [self._batch_size * beam_width], GO_ID)
        # NOTE(review): -1 is presumably never produced as a token id, so beams are presumably never finished by an
        # end token (only by `sequence_length`) - confirm against the helper's step logic.
        self._end_token = -1
        self._beam_width = beam_width
        self._initial_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, initial_state,
            self._cell.state_size)
        # (batch, beam_width) bool: beam 0 is False (not finished) and every other beam is True (finished),
        # i.e. only the first beam of each batch entry starts unfinished.
        self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size],
                                                           dtype=dtypes.int32),
                                           depth=self._beam_width,
                                           on_value=False,
                                           off_value=True,
                                           dtype=dtypes.bool)

        # Zero mask: the mask row for the GO_ID token at the first time step, one copy per beam.
        # NOTE(review): `slicing=[-1, 0, GO_ID, -1]` presumably keeps all batch entries, time step 0, row GO_ID and
        # all vocab columns, giving (batch, vocab_size) after `squeeze` - confirm against `_slice_mask`.
        # After the expand_dims + tile below, zero_mask is (batch, beam, vocab_size).
        self._zero_mask = _slice_mask(self._mask,
                                      slicing=[-1, 0, GO_ID, -1],
                                      squeeze=True,
                                      time_major=self._time_major)
        self._zero_mask = gen_array_ops.tile(
            array_ops.expand_dims(self._zero_mask, axis=1),
            [1, self._beam_width, 1])
        # Zero inputs: zeros shaped like the embedded (and input_layer-projected) start tokens, presumably reshaped
        # by `_split_batch_beams` to (batch, beam, input_size), paired with the first-step mask.
        self._zero_inputs = \
            MaskedInputs(
                inputs=array_ops.zeros_like(
                    self._split_batch_beams(
                        self._input_layer(self._embedding_fn(self._start_tokens)), self._input_size)),
                mask=self._zero_mask)