Example #1
    def step(self, time, inputs, cell_state):
        """ Performs a step using the beam search cell
            :param time: The current time step (scalar)
            :param inputs: A (structure of) input tensors.
            :param cell_state: A (structure of) state tensors and TensorArrays.
            :return: `(cell_outputs, next_cell_state)`.
        """
        raw_inputs = inputs
        inputs, candidates, candidates_emb = raw_inputs.inputs, raw_inputs.candidates, raw_inputs.candidates_emb

        inputs = nest.map_structure(lambda inp: self._merge_batch_beams(inp, depth_shape=inp.shape[2:]), inputs)
        cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size)
        cell_outputs, next_cell_state = self._cell(inputs, cell_state)                  # [batch * beam, out_sz]
        next_cell_state = nest.map_structure(self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

        # Splitting outputs and adding a bias dimension
        # cell_outputs is [batch, beam, cand_emb_size + 1]
        cell_outputs = self._output_layer(cell_outputs) if self._output_layer is not None else cell_outputs
        cell_outputs = nest.map_structure(lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
        cell_outputs = array_ops.pad(cell_outputs, [(0, 0), (0, 0), (0, 1)], constant_values=1.)

        # Computing candidates
        # cell_outputs is reshaped to   [batch, beam,        1, cand_emb_size + 1]
        # candidates_emb is reshaped to [batch,    1, max_cand, cand_emb_size + 1]
        # output_mask is                [batch,    1, max_cand]
        # cell_outputs is finally       [batch, beam, max_cand]
        cell_outputs = math_ops.reduce_sum(array_ops.expand_dims(cell_outputs, axis=2)
                                           * array_ops.expand_dims(candidates_emb, axis=1), axis=-1)
        output_mask = math_ops.cast(array_ops.expand_dims(gen_math_ops.greater(candidates, 0), axis=1), dtypes.float32)
        cell_outputs = gen_math_ops.add(cell_outputs, (1. - output_mask) * LARGE_NEGATIVE)

        # Returning
        return cell_outputs, next_cell_state
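A note on the scoring trick above: padding the cell output with a constant 1 turns the dot product against the padded candidate embeddings into a per-candidate `w·x + b`, and the additive mask disables padded candidate slots. A minimal NumPy sketch, with shapes and the `LARGE_NEGATIVE` value as illustrative assumptions:

import numpy as np

LARGE_NEGATIVE = -1e20                    # assumed stand-in for the module constant
batch, beam, max_cand, emb = 2, 3, 4, 5

cell_outputs = np.random.randn(batch, beam, emb + 1)
cell_outputs[..., -1] = 1.                # the padded bias dimension
candidates_emb = np.random.randn(batch, max_cand, emb + 1)
candidates = np.array([[5, 8, 0, 0],      # candidate token ids, 0 = padding
                       [3, 0, 0, 0]])

# [batch, beam, 1, emb+1] * [batch, 1, max_cand, emb+1] -> sum over last axis
logits = np.sum(cell_outputs[:, :, None, :] * candidates_emb[:, None, :, :], axis=-1)

# Push padded candidates to a huge negative logit so softmax ignores them
output_mask = (candidates > 0).astype(np.float32)[:, None, :]   # [batch, 1, max_cand]
logits = logits + (1. - output_mask) * LARGE_NEGATIVE
print(logits.shape)                       # (2, 3, 4)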
Example #2
    def step(self, time, inputs, cell_state):
        """ Performs a step using the beam search cell
            :param time: The current time step (scalar)
            :param inputs: A (structure of) input tensors.
            :param cell_state: A (structure of) state tensors and TensorArrays.
            :return: `(cell_outputs, next_cell_state)`.
        """
        raw_inputs = inputs
        inputs, output_mask = raw_inputs.inputs, raw_inputs.mask
        inputs = nest.map_structure(
            lambda inp: self._merge_batch_beams(inp, depth_shape=inp.shape[2:]
                                                ), inputs)
        cell_state = nest.map_structure(self._maybe_merge_batch_beams,
                                        cell_state, self._cell.state_size)
        cell_outputs, next_cell_state = self._cell(
            inputs, cell_state)  # [batch * beam, out_sz]
        next_cell_state = nest.map_structure(self._maybe_split_batch_beams,
                                             next_cell_state,
                                             self._cell.state_size)

        # Splitting outputs and applying mask
        # cell_outputs is [batch, beam, vocab_size]
        cell_outputs = self._output_layer(
            cell_outputs) if self._output_layer is not None else cell_outputs
        cell_outputs = nest.map_structure(
            lambda out: self._split_batch_beams(out, out.shape[1:]),
            cell_outputs)
        cell_outputs = gen_math_ops.add(cell_outputs,
                                        (1. - output_mask) * LARGE_NEGATIVE)

        # Returning
        return cell_outputs, next_cell_state
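Why adding `(1. - mask) * LARGE_NEGATIVE` works: after softmax, a logit near -1e20 contributes essentially zero probability, so masked tokens can never be sampled or win a beam. A hedged NumPy sketch (the constant's value is an assumption):

import numpy as np

LARGE_NEGATIVE = -1e20                      # assumed value of the module constant
logits = np.array([2.0, 1.0, 0.5, -0.3])
mask = np.array([1., 1., 0., 1.])           # 0 marks a disallowed token

masked = logits + (1. - mask) * LARGE_NEGATIVE
probs = np.exp(masked - masked.max())       # numerically stable softmax
probs /= probs.sum()
print(probs)                                # the masked entry gets probability 0.0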
Example #3
    def __init__(self, cell, order_embedding, candidate_embedding, candidates, sequence_length, initial_state,
                 beam_width, input_layer=None, output_layer=None, time_major=False):
        """ Initialize the CustomBeamHelper
            :param cell: An `RNNCell` instance.
            :param order_embedding: The order embedding vector  - Size: (batch, ord_emb_size)
            :param candidate_embedding: The candidate embedding vector - Size: (batch, cand_emb_size)
            :param candidates: The candidates at each time step -- Size: (batch, nb_cand, max_candidates)
            :param sequence_length: The length of each sequence (batch,)
            :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
            :param beam_width: Python integer, the number of beams.
            :param input_layer: Optional. A layer to apply on the inputs
            :param output_layer: Optional. An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer
                                 to apply to the RNN output prior to storing the result or sampling.
            :param time_major: If true indicates that the first dimension is time, otherwise it is batch size.
        """
        # pylint: disable=super-init-not-called,too-many-arguments
        rnn_cell_impl.assert_like_rnncell('cell', cell)                                                                 # pylint: disable=protected-access
        assert isinstance(beam_width, int), 'beam_width should be a Python integer'

        self._sequence_length = ops.convert_to_tensor(sequence_length, name='sequence_length')
        if self._sequence_length.get_shape().ndims != 1:
            raise ValueError("Expected vector for sequence_length. Shape: %s" % self._sequence_length.get_shape())

        candidates = ops.convert_to_tensor(candidates, name='candidates')
        candidates = nest.map_structure(_transpose_batch_time, candidates) if not time_major else candidates

        self._cell = cell
        self._order_embedding_fn = _get_embedding_fn(order_embedding)
        self._candidate_embedding_fn = _get_embedding_fn(candidate_embedding)
        self._candidate_tas = nest.map_structure(_unstack_ta, candidates)
        self._input_layer = input_layer if input_layer is not None else lambda x: x
        self._output_layer = output_layer

        self._input_size = order_embedding.shape[-1]
        if input_layer is not None:
            self._input_size = self._input_layer.compute_output_shape([None, self._input_size])[-1]

        self._batch_size = array_ops.size(sequence_length)
        self._start_tokens = gen_array_ops.fill([self._batch_size * beam_width], GO_ID)
        self._end_token = -1
        self._beam_width = beam_width
        self._initial_cell_state = nest.map_structure(self._maybe_split_batch_beams,
                                                      initial_state,
                                                      self._cell.state_size)
        self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size], dtype=dtypes.int32),
                                           depth=self._beam_width,
                                           on_value=False,
                                           off_value=True,
                                           dtype=dtypes.bool)

        # Compute input shape
        self._zero_inputs = \
            CandidateInputs(inputs=
                            array_ops.zeros_like(self._split_batch_beams(
                                self._input_layer(self._order_embedding_fn(self._start_tokens)),
                                self._input_size)),
                            candidates=array_ops.zeros_like(candidates[0, :]),
                            candidates_emb=array_ops.zeros_like(self._candidate_embedding_fn(candidates[0, :])))
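The `_merge_batch_beams` / `_split_batch_beams` helpers used throughout this class fold the beam axis into the batch axis so the wrapped `RNNCell` only ever sees a flat batch. A NumPy sketch of the round trip (a simplification; the real helpers also handle partially unknown shapes):

import numpy as np

batch, beam, depth = 2, 3, 5
x = np.random.randn(batch, beam, depth)

merged = x.reshape(batch * beam, depth)       # _merge_batch_beams: fed to the cell
split = merged.reshape(batch, beam, depth)    # _split_batch_beams: back per beam

assert np.array_equal(x, split)               # lossless round trip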
Example #4
    def output_dtype(self):
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        # Return that structure and the sample_ids_dtype from the helper.
        dtype = nest.flatten(self._initial_state)[0].dtype
        if self.extract_state:
            return BasicDecoderWithStateOutput(
                nest.map_structure(lambda _: dtype, self._rnn_output_size()),
                dtype, self._helper.sample_ids_dtype)
        return seq2seq.BasicDecoderOutput(
            nest.map_structure(lambda _: dtype, self._rnn_output_size()),
            self._helper.sample_ids_dtype)
Example #5
    def output_size(self):
        """ Returns the size of the RNN output """
        size = self._cell.output_size
        if self._output_layer is None:
            return size

        # To use layer's compute_output_shape, we need to convert the RNNCell's output_size entries into shapes
        # with an unknown batch size.  We then pass this through the layer's compute_output_shape and read off
        # all but the first (batch) dimensions to get the output size of the rnn with the layer applied to the top.
        output_shape_with_unknown_batch = \
            nest.map_structure(lambda shape: tensor_shape.TensorShape([None]).concatenate(shape), size)
        layer_output_shape = self._output_layer.compute_output_shape(output_shape_with_unknown_batch)
        return nest.map_structure(lambda shape: shape[1:], layer_output_shape)
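The same unknown-batch trick can be exercised against a Keras `Dense` layer (using Keras here is an assumption; the method above works with any layer exposing `compute_output_shape`):

import tensorflow as tf

layer = tf.keras.layers.Dense(7)

# Prepend an unknown batch dimension, query the layer, then strip the batch
# dimension again to recover the per-step output size.
shape_with_unknown_batch = tf.TensorShape([None]).concatenate(tf.TensorShape([16]))
print(layer.compute_output_shape(shape_with_unknown_batch)[1:])   # (7,)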
Example #6
    def __init__(self, decoder_type, inputs, order_embedding, candidate_embedding, sequence_length, candidates,
                 input_layer=None, time_major=False, softmax_temperature=None, seed=None, name=None):
        """ Constructor
            :param decoder_type: An uint8 representing TRAINING_DECODER, GREEDY_DECODER, or SAMPLE_DECODER
            :param inputs: The decoder input (b, dec_len)
            :param order_embedding: The order embedding vector
            :param candidate_embedding: The candidate embedding vector
            :param sequence_length: The length of each input (b,)
            :param candidates: The candidates at each time step -- Size: (b, nb_cand, max_candidates)
            :param input_layer: Optional. A layer to apply on the inputs
            :param time_major: If true indicates that the first dimension is time, otherwise it is batch size
            :param softmax_temperature: Optional. Softmax temperature. None, scalar, or size: (batch_size,)
            :param seed: Optional. The sampling seed
            :param name: Optional scope name.
        """
        # pylint: disable=too-many-arguments
        with ops.name_scope(name, "CustomHelper", [inputs, sequence_length, order_embedding, candidate_embedding]):
            inputs = ops.convert_to_tensor(inputs, name="inputs")
            candidates = ops.convert_to_tensor(candidates, name="candidates")
            self._inputs = inputs
            self._order_embedding_fn = _get_embedding_fn(order_embedding)
            self._candidate_embedding_fn = _get_embedding_fn(candidate_embedding)
            if not time_major:
                inputs = nest.map_structure(_transpose_batch_time, inputs)
                candidates = nest.map_structure(_transpose_batch_time, candidates)
            self._input_tas = nest.map_structure(_unstack_ta, inputs)
            self._candidate_tas = nest.map_structure(_unstack_ta, candidates)
            self._decoder_type = decoder_type
            self._sequence_length = ops.convert_to_tensor(sequence_length, name="sequence_length")
            if self._sequence_length.get_shape().ndims != 1:
                raise ValueError("Expected vector for sequence_length. Shape: %s" % self._sequence_length.get_shape())
            self._input_layer = input_layer if input_layer is not None else lambda x: x
            self._batch_size = array_ops.size(sequence_length)
            self._start_inputs = gen_array_ops.fill([self._batch_size], GO_ID)
            self._softmax_temperature = softmax_temperature
            self._seed = seed

            # Compute input shape
            self._zero_inputs = \
                CandidateInputs(inputs=
                                array_ops.zeros_like(self._input_layer(self._order_embedding_fn(self._start_inputs))),
                                candidates=array_ops.zeros_like(candidates[0, :]),
                                candidates_emb=array_ops.zeros_like(self._candidate_embedding_fn(candidates[0, :])))

            # Preventing div by zero
            # Adding an extra dim to the matrix, so we can broadcast with the outputs shape
            if softmax_temperature is not None:
                self._softmax_temperature = gen_math_ops.maximum(1e-10, self._softmax_temperature)
                if self._softmax_temperature.get_shape().ndims == 1:
                    self._softmax_temperature = self._softmax_temperature[:, None]
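The `[:, None]` expansion at the end is what lets a per-example temperature of shape (batch,) broadcast against logits of shape (batch, vocab) when the outputs are later divided by it. A NumPy sketch of the guard and the broadcast (values are illustrative):

import numpy as np

logits = np.random.randn(4, 10)                  # (batch, vocab)
temperature = np.array([0.5, 1.0, 2.0, 0.0])     # (batch,), one entry is zero

temperature = np.maximum(1e-10, temperature)     # the div-by-zero guard above
scaled = logits / temperature[:, None]           # (4, 1) broadcasts over vocab
print(scaled.shape)                              # (4, 10)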
Example #7
        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """ Internal while_loop body. """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = gen_math_ops.logical_or(
                    decoder_finished, finished)
            next_sequence_lengths = array_ops.where(
                gen_math_ops.logical_not(finished),
                gen_array_ops.fill(array_ops.shape(sequence_lengths),
                                   time + 1), sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays, multiple dynamic dims, and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                elif None in new.shape.as_list()[1:]:
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)
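What `impute_finished` buys, reduced to a toy NumPy form: a finished row emits zeros and carries its previous state forward unchanged, so steps past the end of a sequence cannot corrupt outputs or final states. This illustrates the semantics only, not the TensorArray plumbing above:

import numpy as np

finished = np.array([True, False])                # per-batch-entry flag
next_outputs = np.array([[0.7, 0.1], [0.2, 0.9]])
zero_outputs = np.zeros_like(next_outputs)
old_state = np.array([[1., 1.], [2., 2.]])
new_state = np.array([[9., 9.], [3., 3.]])

emit = np.where(finished[:, None], zero_outputs, next_outputs)  # zero finished rows
state = np.where(finished[:, None], old_state, new_state)       # copy state through
print(emit)     # [[0.  0. ] [0.2 0.9]]
print(state)    # [[1. 1.] [3. 3.]]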
Example #8
    def __init__(self,
                 cell,
                 concat_inputs,
                 mask_inputs=None,
                 embedding=None,
                 name=None):
        """ Constructs an ArrayConcatWrapper

            If embedding is provided, the concat_inputs is expected to be [batch, time]
            If embedding is not provided, the concat_inputs is expected to be [batch, time, input_size]

            mask_inputs of True will mask (zero-out) the given input (or embedded input)

            :param cell: An instance of `RNNCell`.
            :param concat_inputs: The inputs to concatenate [batch, time] or [batch, time, input_size]
            :param mask_inputs: Optional. Boolean [batch, time] that indicates if the concat_inputs is to be masked
            :param embedding: Optional. Embedding fn or embedding vector to embed the concat_inputs at each time step
            :param name: Name to use when creating ops.
        """
        # pylint: disable=too-many-arguments
        # Initializing RNN Cell
        super(ArrayConcatWrapper, self).__init__(name=name)
        rnn_cell_impl.assert_like_rnncell('cell', cell)

        # Setting values
        self._cell = cell
        self._cell_input_fn = lambda input_1, input_2: array_ops.concat(
            [input_1, input_2], axis=-1)
        self._embedding_fn = _get_embedding_fn(embedding)
        self._mask_inputs_ta = None

        # Converting mask inputs to a tensor array
        if mask_inputs is not None:
            mask_inputs = nest.map_structure(_transpose_batch_time,
                                             mask_inputs)
            self._mask_inputs_ta = nest.map_structure(
                _unstack_ta, mask_inputs)  # [time, batch]

        # Converting concat_inputs to a tensor array
        concat_inputs = nest.map_structure(_transpose_batch_time,
                                           concat_inputs)
        self._concat_inputs_ta = nest.map_structure(
            _unstack_ta, concat_inputs)  # [time, batch] / [t, b, inp_sz]
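`_transpose_batch_time` followed by `_unstack_ta` is the recurring recipe for per-step reads in these wrappers: flip to time-major, then index along axis 0. In NumPy terms, each list entry below plays the role of one `ta.read(time)`:

import numpy as np

batch, time, depth = 2, 4, 3
concat_inputs = np.random.randn(batch, time, depth)

time_major = np.transpose(concat_inputs, (1, 0, 2))   # (time, batch, depth)
per_step = list(time_major)                           # one (batch, depth) slice per step
print(per_step[0].shape)                              # (2, 3)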
Example #9
    def clone(self, **kwargs):
        """ Clone this object, overriding components provided by kwargs. """
        def with_same_shape(old, new):
            """Check and set new tensor's shape."""
            if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
                return contrib_framework.with_same_shape(old, new)
            return new

        return nest.map_structure(
            with_same_shape, self,
            super(SelfAttentionWrapperState, self)._replace(**kwargs))
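`SelfAttentionWrapperState` is a namedtuple, so `clone` is essentially `_replace` plus a shape check. A sketch with a hypothetical two-field state:

import collections

# Hypothetical stand-in for a namedtuple-based wrapper state
State = collections.namedtuple('State', ['cell_state', 'time'])
state = State(cell_state=[1.0, 2.0], time=3)

# _replace returns a copy overriding only the named fields, which is what
# clone(**kwargs) delegates to before enforcing matching tensor shapes.
print(state._replace(time=4))    # State(cell_state=[1.0, 2.0], time=4)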
Example #10
    def zero_state(self, batch_size, dtype):
        """ Return an initial (zero) state tuple for this `AttentionWrapper`.
            :param batch_size: `0D` integer tensor: the batch size.
            :param dtype: The internal state data type.
            :return: AttentionWrapperState` tuple containing zeroed out tensors and, possibly, empty `TensorArrays`.
        """
        with ops.name_scope(type(self).__name__ + 'ZeroState',
                            values=[batch_size]):
            if self._initial_cell_state is not None:
                cell_state = self._initial_cell_state
            else:
                cell_state = self._cell.zero_state(batch_size, dtype)

            error_message = (
                'When calling zero_state of AttentionWrapper %s: ' %
                self._base_name +
                'Non-matching batch sizes between the memory (encoder output) and the requested batch '
                'size. Are you using the BeamSearchDecoder? If so, make sure your encoder output has been '
                'tiled to beam_width via tf.contrib.seq2seq.tile_batch, and the batch_size= argument '
                'passed to zero_state is batch_size * beam_width.')
            with ops.control_dependencies(
                    self._batch_size_checks(batch_size, error_message)):
                cell_state = nest.map_structure(
                    lambda state: array_ops.identity(
                        state, name='checked_cell_state'), cell_state)
            initial_alignments = [
                attention_mechanism.initial_alignments(batch_size, dtype)
                for attention_mechanism in self._attention_mechanisms
            ]
            return AttentionWrapperState(
                cell_state=cell_state,
                time=array_ops.zeros([], dtype=dtypes.int32),
                attention=_zero_state_tensors(self._attention_layer_size,
                                              batch_size, dtype),
                alignments=self._item_or_tuple(initial_alignments),
                attention_state=self._item_or_tuple(
                    attention_mechanism.initial_state(batch_size, dtype)
                    for attention_mechanism in self._attention_mechanisms),
                alignment_history=self._item_or_tuple(
                    tensor_array_ops.TensorArray(dtype,
                                                 size=0,
                                                 dynamic_size=True,
                                                 element_shape=alignment.shape)
                    if self._alignment_history else ()
                    for alignment in initial_alignments))
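The tiling the error message insists on: `tile_batch` repeats each batch entry beam_width times along axis 0, interleaved per entry, which is why a plain `tf.tile` is wrong. A NumPy sketch of the effect:

import numpy as np

batch, time, units, beam_width = 2, 5, 8, 3
encoder_output = np.random.randn(batch, time, units)

# tile_batch-style repeat: [a, b] -> [a, a, a, b, b, b]
tiled = np.repeat(encoder_output, beam_width, axis=0)
print(tiled.shape)    # (6, 5, 8) == (batch * beam_width, time, units)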
Example #11
    def __init__(self,
                 cell,
                 memory,
                 alignments,
                 sequence_length,
                 probability_fn=None,
                 score_mask_value=None,
                 attention_layer_size=None,
                 cell_input_fn=None,
                 output_attention=False,
                 name=None):
        """ Constructs an AttentionWrapper with static alignments (attention weights)

            :param cell: An instance of `RNNCell`.
            :param memory: The memory to query [batch_size, memory_time, memory_size]
            :param alignments: A tensor of probabilities of shape [batch_size, time_steps, memory_time]
            :param sequence_length: Sequence lengths for the batch entries in memory. Size (b,)
            :param probability_fn: A `callable`. Converts the score to probabilities. The default is `tf.nn.softmax`.
            :param score_mask_value:  The mask value for score before passing into `probability_fn`. Default is -inf.
            :param attention_layer_size: The size of the attention layer. Uses the context if None.
            :param cell_input_fn: (optional) A `callable` to aggregate attention.
                                  Default: `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
            :param output_attention: If true, outputs the attention, if False outputs the cell output.
            :param name: Name to use when creating ops.
        """
        # pylint: disable=too-many-arguments
        # Initializing RNN Cell
        super(StaticAttentionWrapper, self).__init__(name=name)
        rnn_cell_impl.assert_like_rnncell('cell', cell)

        # Setting values
        self._cell = cell
        self._memory = memory
        self._attention_layer_size = attention_layer_size
        self._output_attention = output_attention
        self._memory_time = alignments.get_shape()[-1].value
        self._memory_size = memory.get_shape()[-1].value
        self._sequence_length = sequence_length

        # Validating attention layer size
        if self._attention_layer_size is None:
            self._attention_layer_size = self._memory_size

        # Validating cell_input_fn
        if cell_input_fn is None:
            cell_input_fn = lambda inputs, attention: array_ops.concat(
                [inputs, attention], -1)
        else:
            if not callable(cell_input_fn):
                raise TypeError(
                    'cell_input_fn must be callable, saw type: %s' %
                    type(cell_input_fn).__name__)
        self._cell_input_fn = cell_input_fn

        # Probability Function
        if probability_fn is None:
            probability_fn = nn_ops.softmax
        if score_mask_value is None:
            score_mask_value = dtypes.as_dtype(
                self._memory.dtype).as_numpy_dtype(-np.inf)
        self._probability_fn = lambda score, _: probability_fn(
            _maybe_mask_score(score, sequence_length, score_mask_value), _)

        # Storing alignments as TA
        # Padding with 1 additional zero, to prevent error on read(0)
        alignments = array_ops.pad(alignments, [(0, 0), (0, 1), (0, 0)])
        alignments = nest.map_structure(
            _transpose_batch_time,
            alignments)  # (max_time + 1, b, memory_time)
        self._alignments_ta = nest.map_structure(
            _unstack_ta, alignments)  # [time_step + 1, batch, memory_time]
        self._initial_alignment = self._alignments_ta.read(0)
        self._initial_attention = self._compute_attention(
            self._initial_alignment, self._memory)[0]

        # Storing zero inputs
        batch_size = array_ops.shape(memory)[0]
        self._zero_cell_output = array_ops.zeros(
            [batch_size, cell.output_size])
        self._zero_attention = array_ops.zeros(
            [batch_size, self._attention_layer_size])
        self._zero_state = self.zero_state(batch_size, dtypes.float32)
        self._zero_alignment = array_ops.zeros_like(self._initial_alignment)
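The attention derived from a static alignment row is a probability-weighted sum over memory. `_compute_attention` is not shown here, so this NumPy sketch of the standard context computation is an assumption about its core step:

import numpy as np

batch, memory_time, memory_size = 2, 4, 6
memory = np.random.randn(batch, memory_time, memory_size)
alignments = np.random.rand(batch, memory_time)
alignments /= alignments.sum(axis=1, keepdims=True)   # each row sums to 1

# context[b] = sum_t alignments[b, t] * memory[b, t, :]
context = np.einsum('bt,btd->bd', alignments, memory)
print(context.shape)    # (2, 6)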
Example #12
    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name_or_scope='AttentionWrapper',
                 attention_layer=None):
        """ Construct the `AttentionWrapper`.

            :param cell: An instance of `RNNCell`.
            :param attention_mechanism: A list of `AttentionMechanism` instances or a single instance.
            :param attention_layer_size: A list of Python integers or a single Python integer,
                                         the depth of the attention (output) layer(s).
            :param alignment_history: Python boolean, whether to store alignment history from all time steps in the
                                      final output state
            :param cell_input_fn: (optional) A `callable`. The default is: concat([inputs, attention], axis=-1)
            :param output_attention: Python bool. If `True` (default), the output at each time step is the attn value.
            :param initial_cell_state: The initial state value to use for the cell when the user calls `zero_state()`.
            :param name_or_scope: String or VariableScope to use when creating ops.
            :param attention_layer: A list of `tf.layers.Layer` instances or a single `tf.layers.Layer` instance taking
                                    the context and cell output as inputs to generate attention at each time step.
                                    If None (default), use the context as attention at each time step.

            **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`,
                     then you must ensure that:

            - The encoder output has been tiled to `beam_width` via `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
            - The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to
                `true_batch_size * beam_width`.
            - The initial state created with `zero_state` above contains a `cell_state` value containing properly
                tiled final state from the encoder.
        """
        # pylint: disable=too-many-arguments
        self._name_or_scope = name_or_scope
        with variable_scope.variable_scope(name_or_scope, 'AttentionWrapper'):
            super(AttentionWrapper, self).__init__()
            rnn_cell_impl.assert_like_rnncell("cell", cell)

            # Attention mechanism
            if isinstance(attention_mechanism, (list, tuple)):
                self._is_multi = True
                attention_mechanisms = attention_mechanism
                for attn_mechanism in attention_mechanisms:
                    if not isinstance(attn_mechanism, AttentionMechanism):
                        raise TypeError(
                            'attention_mechanism must contain only instances of AttentionMechanism, saw '
                            'type: %s' % type(attn_mechanism).__name__)
            else:
                self._is_multi = False
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        'attention_mechanism must be an AttentionMechanism or list of multiple '
                        'AttentionMechanism instances, saw type: %s' %
                        type(attention_mechanism).__name__)
                attention_mechanisms = (attention_mechanism, )

            # Cell input function
            if cell_input_fn is None:
                cell_input_fn = lambda inputs, attention: array_ops.concat(
                    [inputs, attention], -1)
            else:
                if not callable(cell_input_fn):
                    raise TypeError(
                        'cell_input_fn must be callable, saw type: %s' %
                        type(cell_input_fn).__name__)

            # Attention layer size
            if attention_layer_size is not None and attention_layer is not None:
                raise ValueError(
                    'Only one of attention_layer_size and attention_layer should be set'
                )

            if attention_layer_size is not None:
                attention_layer_sizes = tuple(
                    attention_layer_size if isinstance(attention_layer_size, (
                        list, tuple)) else (attention_layer_size, ))
                if len(attention_layer_sizes) != len(attention_mechanisms):
                    raise ValueError(
                        'If provided, attention_layer_size must contain exactly one integer per '
                        'attention_mechanism, saw: %d vs %d' %
                        (len(attention_layer_sizes),
                         len(attention_mechanisms)))
                self._attention_layers = tuple(
                    core.Dense(attention_layer_size,
                               name='attention_layer',
                               use_bias=False,
                               dtype=attention_mechanisms[i].dtype) for i,
                    attention_layer_size in enumerate(attention_layer_sizes))
                self._attention_layer_size = sum(attention_layer_sizes)

            elif attention_layer is not None:
                self._attention_layers = tuple(attention_layer if isinstance(
                    attention_layer, (list, tuple)) else (attention_layer, ))
                if len(self._attention_layers) != len(attention_mechanisms):
                    raise ValueError(
                        'If provided, attention_layer must contain exactly one layer per '
                        'attention_mechanism, saw: %d vs %d' %
                        (len(self._attention_layers),
                         len(attention_mechanisms)))
                self._attention_layer_size = \
                    sum(tensor_shape.dimension_value(
                        layer.compute_output_shape([None,
                                                    cell.output_size
                                                    + tensor_shape.dimension_value(mechanism.values.shape[-1])])[-1])
                        for layer, mechanism in zip(self._attention_layers, attention_mechanisms))
            else:
                self._attention_layers = None
                self._attention_layer_size = sum(
                    tensor_shape.dimension_value(
                        attention_mechanism.values.shape[-1])
                    for attention_mechanism in attention_mechanisms)

            self._cell = cell
            self._attention_mechanisms = attention_mechanisms
            self._cell_input_fn = cell_input_fn
            self._output_attention = output_attention
            self._alignment_history = alignment_history

            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (tensor_shape.dimension_value(
                    final_state_tensor.shape[0])
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    'When constructing AttentionWrapper %s: ' % self._base_name
                    +
                    'Non-matching batch sizes between the memory (encoder output) and initial_cell_state. '
                    'Are you using the BeamSearchDecoder? You may need to tile your initial state via the '
                    'tf.contrib.seq2seq.tile_batch function with argument multiple=beam_width.'
                )

                with ops.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = \
                        nest.map_structure(lambda state: array_ops.identity(state, name='check_initial_cell_state'),
                                           initial_cell_state)
Example #13
    def __init__(self,
                 cell,
                 embedding,
                 mask,
                 sequence_length,
                 initial_state,
                 beam_width,
                 input_layer=None,
                 output_layer=None,
                 time_major=False):
        """ Initialize the CustomBeamHelper
            :param cell: An `RNNCell` instance.
            :param embedding: The embedding vector
            :param mask: [SparseTensor] Mask to apply at each time step -- Size: (b, dec_len, vocab_size, vocab_size)
            :param sequence_length: The length of each input (b,)
            :param initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
            :param beam_width: Python integer, the number of beams.
            :param input_layer: Optional. A layer to apply on the inputs
            :param output_layer: Optional. An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer
                                 to apply to the RNN output prior to storing the result or sampling.
            :param time_major: If true indicates that the first dimension is time, otherwise it is batch size.
        """
        # pylint: disable=super-init-not-called,too-many-arguments
        rnn_cell_impl.assert_like_rnncell('cell', cell)  # pylint: disable=protected-access
        assert isinstance(mask,
                          SparseTensor), 'The mask must be a SparseTensor'
        assert isinstance(beam_width,
                          int), 'beam_width should be a Python integer'

        self._sequence_length = ops.convert_to_tensor(sequence_length,
                                                      name='sequence_length')
        if self._sequence_length.get_shape().ndims != 1:
            raise ValueError("Expected vector for sequence_length. Shape: %s" %
                             self._sequence_length.get_shape())

        self._cell = cell
        self._embedding_fn = _get_embedding_fn(embedding)
        self._mask = mask
        self._time_major = time_major
        self.vocab_size = VOCABULARY_SIZE
        self._input_layer = input_layer if input_layer is not None else lambda x: x
        self._output_layer = output_layer

        self._input_size = embedding.shape[-1]
        if input_layer is not None:
            self._input_size = self._input_layer.compute_output_shape(
                [None, self._input_size])[-1]

        self._batch_size = array_ops.size(sequence_length)
        self._start_tokens = gen_array_ops.fill(
            [self._batch_size * beam_width], GO_ID)
        self._end_token = -1
        self._beam_width = beam_width
        self._initial_cell_state = nest.map_structure(
            self._maybe_split_batch_beams, initial_state,
            self._cell.state_size)
        self._finished = array_ops.one_hot(array_ops.zeros([self._batch_size],
                                                           dtype=dtypes.int32),
                                           depth=self._beam_width,
                                           on_value=False,
                                           off_value=True,
                                           dtype=dtypes.bool)

        # zero_mask is (batch, beam, vocab_size)
        self._zero_mask = _slice_mask(self._mask,
                                      slicing=[-1, 0, GO_ID, -1],
                                      squeeze=True,
                                      time_major=self._time_major)
        self._zero_mask = gen_array_ops.tile(
            array_ops.expand_dims(self._zero_mask, axis=1),
            [1, self._beam_width, 1])
        self._zero_inputs = \
            MaskedInputs(
                inputs=array_ops.zeros_like(
                    self._split_batch_beams(
                        self._input_layer(self._embedding_fn(self._start_tokens)), self._input_size)),
                mask=self._zero_mask)
Example #14
    def __init__(self,
                 decoder_type,
                 inputs,
                 embedding,
                 sequence_length,
                 mask,
                 input_layer=None,
                 time_major=False,
                 softmax_temperature=None,
                 seed=None,
                 name=None):
        """ Constructor
            :param decoder_type: An uint8 representing TRAINING_DECODER, GREEDY_DECODER, or SAMPLE_DECODER
            :param inputs: The decoder input (b, dec_len)
            :param embedding: The embedding vector
            :param sequence_length: The length of each input (b,)
            :param mask: [SparseTensor] Mask to apply at each time step -- Size: (b, dec_len, vocab_size, vocab_size)
            :param input_layer: Optional. A layer to apply on the inputs
            :param time_major: If true indicates that the first dimension is time, otherwise it is batch size
            :param softmax_temperature: Optional. Softmax temperature. None or size: (batch_size,)
            :param seed: Optional. The sampling seed
            :param name: Optional scope name.
        """
        # pylint: disable=too-many-arguments
        with ops.name_scope(name, "CustomHelper",
                            [inputs, sequence_length, embedding]):
            assert isinstance(mask,
                              SparseTensor), 'The mask must be a SparseTensor'
            inputs = ops.convert_to_tensor(inputs, name="inputs")
            self._inputs = inputs
            self._mask = mask
            self._time_major = time_major
            self._embedding_fn = embedding if callable(
                embedding) else lambda ids: embedding_lookup(embedding, ids)
            if not time_major:
                inputs = nest.map_structure(_transpose_batch_time, inputs)
            self._input_tas = nest.map_structure(_unstack_ta, inputs)
            self._decoder_type = decoder_type
            self._sequence_length = ops.convert_to_tensor(
                sequence_length, name="sequence_length")
            if self._sequence_length.get_shape().ndims != 1:
                raise ValueError(
                    "Expected vector for sequence_length. Shape: %s" %
                    self._sequence_length.get_shape())
            self._input_layer = input_layer if callable(
                input_layer) else lambda x: x
            self._batch_size = array_ops.size(sequence_length)
            self._start_inputs = gen_array_ops.fill([self._batch_size], GO_ID)
            self._softmax_temperature = softmax_temperature
            self._seed = seed
            self.vocab_size = VOCABULARY_SIZE
            self._zero_inputs = \
                MaskedInputs(inputs=array_ops.zeros_like(self._input_layer(self._embedding_fn(self._start_inputs))),
                             mask=_slice_mask(self._mask,
                                              slicing=[-1, 0, GO_ID, -1],
                                              squeeze=True,
                                              time_major=self._time_major))

            # Preventing div by zero
            # Adding an extra dim to the matrix, so we can broadcast with the outputs shape
            if softmax_temperature is not None:
                self._softmax_temperature = gen_math_ops.maximum(
                    1e-10, self._softmax_temperature)
                if self._softmax_temperature.get_shape().ndims == 1:
                    self._softmax_temperature = self._softmax_temperature[:,
                                                                          None]
Example #15
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   invariants_map=None,
                   swap_memory=False,
                   scope=None):
    """ Performs dynamic decoding with `decoder`.
        :param decoder: A `Decoder` instance.
        :param output_time_major: If True, outputs [time, batch, ...], otherwise outputs [batch, time, ...]
        :param impute_finished: If true, finished states are copied through to the end of decoding
        :param maximum_iterations: Int or None. The maximum number of steps (otherwise decode until it's done)
        :param parallel_iterations: Argument passed to tf.while_loop
        :param invariants_map: Optional. Dictionary of tensor path (in initial_state) to its shape invariant.
        :param swap_memory: Argument passed to `tf.while_loop`.
        :param scope: Optional variable scope to use.
        :return: A tuple of 1) final_outputs, 2) final_state, 3) final_sequence_length
    """
    if not isinstance(decoder, seq2seq.Decoder):
        raise TypeError('Expected decoder to be type Decoder, but saw: %s' %
                        type(decoder))

    with variable_scope.variable_scope(scope, 'decoder') as varscope:

        # Determine context types.
        ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
        is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
        in_while_loop = control_flow_util.GetContainingWhileContext(
            ctxt) is not None

        # Properly cache variable values inside the while_loop.
        # Don't set a caching device when running in a loop, since it is possible that train steps could be wrapped
        # in a tf.while_loop. In that scenario caching prevents forward computations in loop iterations from re-reading
        # the updated weights.
        if not context.executing_eagerly() and not in_while_loop:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        # Setting maximum iterations
        if maximum_iterations is not None:
            maximum_iterations = ops.convert_to_tensor(
                maximum_iterations,
                dtype=dtypes.int32,
                name="maximum_iterations")
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError('maximum_iterations must be a scalar')

        def _inv_shape(maybe_ta):
            """ Returns the invariatns shape """
            if isinstance(maybe_ta, tensor_array_ops.TensorArray):
                return maybe_ta.flow.shape
            return maybe_ta.shape

        def _invariants(structure):
            """ Returns the invariants of a structure """
            return nest.map_structure(_inv_shape, structure)

        def _map_invariants(structure):
            """ Returns the invariants of a structure, but replaces the invariant using the value in invariants_map """
            return nest.map_structure_with_paths(
                lambda path, tensor:
                (invariants_map or {}).get(path, _inv_shape(tensor)),
                structure)

        # Initializing decoder
        initial_finished, initial_inputs, initial_state = decoder.initialize()
        zero_outputs = _create_zero_outputs(decoder.output_size,
                                            decoder.output_dtype,
                                            decoder.batch_size)

        if is_xla and maximum_iterations is None:
            raise ValueError(
                'maximum_iterations is required for XLA compilation.')
        if maximum_iterations is not None:
            initial_finished = gen_math_ops.logical_or(initial_finished,
                                                       maximum_iterations <= 0)
        initial_sequence_lengths = array_ops.zeros_like(initial_finished,
                                                        dtype=dtypes.int32)
        initial_time = constant_op.constant(0, dtype=dtypes.int32)

        # Creating initial output TA
        def _shape(batch_size, from_shape):
            """ Returns the batch_size concatenated with the from_shape """
            if (not isinstance(from_shape, tensor_shape.TensorShape)
                    or from_shape.ndims == 0):
                return tensor_shape.TensorShape(None)
            batch_size = tensor_util.constant_value(
                ops.convert_to_tensor(batch_size, name='batch_size'))
            return tensor_shape.TensorShape([batch_size
                                             ]).concatenate(from_shape)

        dynamic_size = maximum_iterations is None or not is_xla

        def _create_ta(shape, dtype):
            """ Creates a tensor array"""
            return tensor_array_ops.TensorArray(
                dtype=dtype,
                size=0 if dynamic_size else maximum_iterations,
                dynamic_size=dynamic_size,
                element_shape=_shape(decoder.batch_size, shape))

        initial_outputs_ta = nest.map_structure(_create_ta,
                                                decoder.output_size,
                                                decoder.output_dtype)

        def condition(unused_time, unused_outputs_ta, unused_state,
                      unused_inputs, finished, unused_sequence_lengths):
            """ While loop condition"""
            return gen_math_ops.logical_not(math_ops.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """ Internal while_loop body. """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = gen_math_ops.logical_or(
                    decoder_finished, finished)
            next_sequence_lengths = array_ops.where(
                gen_math_ops.logical_not(finished),
                gen_array_ops.fill(array_ops.shape(sequence_lengths),
                                   time + 1), sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays, multiple dynamic dims, and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                elif None in new.shape.as_list()[1:]:
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)

        res = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=(initial_time, initial_outputs_ta, initial_state,
                       initial_inputs, initial_finished,
                       initial_sequence_lengths),
            shape_invariants=(_invariants(initial_time),
                              _invariants(initial_outputs_ta),
                              _map_invariants(initial_state),
                              _invariants(initial_inputs),
                              _invariants(initial_finished),
                              _invariants(initial_sequence_lengths)),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory)

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = nest.map_structure(lambda ta: ta.stack(),
                                           final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:
            pass

        if not output_time_major:
            final_outputs = nest.map_structure(_transpose_batch_time,
                                               final_outputs)

    return final_outputs, final_state, final_sequence_lengths
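The control flow of the while_loop above, stripped to a pure-Python analogue; the step function and its return values are stand-ins, and the real loop additionally threads TensorArrays, sequence lengths, and shape invariants:

def toy_dynamic_decode(step_fn, init_inputs, init_state, max_iters=10):
    """ Pure-Python analogue of the decode loop: step until finished or capped. """
    time, inputs, state, outputs, finished = 0, init_inputs, init_state, [], False
    while not finished and time < max_iters:
        out, state, inputs, finished = step_fn(time, inputs, state)
        outputs.append(out)
        time += 1
    return outputs, state, time

# A stand-in decoder step that finishes after three steps
print(toy_dynamic_decode(lambda t, i, s: (float(t), s, i, t >= 2), 0.0, None))
# ([0.0, 1.0, 2.0], None, 3)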