Code Example #1
File: transformer.py Project: zhanpengfang/research
    def _conv1d(inputs, output_dim, name=None):
        """ Performs 1d convolution
            :param inputs: The tensor inputs - [axis_0, ..., axis_n-1, input_dim]
            :param output_dim: The output size of the 1d convolution
            :param name: Name of the scope - To share weights between calls
            :return: The tensors after the convolution - [axis_0, ..., axis_n-1, output_dim]
        """
        with variable_scope.variable_scope(name):
            input_dims = [
                array_ops.shape(inputs)[axis]
                if dim.value is None else dim.value
                for axis, dim in enumerate(inputs.shape)
            ]
            input_prefix_dims, input_last_dim = input_dims[:-1], input_dims[-1]
            weight = variable_scope.get_variable(
                'w',
                shape=[input_last_dim, output_dim],
                initializer=init_ops.random_normal_initializer(stddev=0.02))
            beta = variable_scope.get_variable(
                'b',
                shape=[output_dim],
                initializer=init_ops.constant_initializer(0.))

            inputs = gen_array_ops.reshape(
                inputs, [-1, input_last_dim])  # [B, input_last_dim]
            outputs = math_ops.matmul(inputs, weight) + beta  # [B, output_dim]
            return gen_array_ops.reshape(outputs, input_prefix_dims +
                                         [output_dim])  # [..., output_dim]
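The "1d convolution" above is simply a pointwise projection of the last axis (GPT-2 style). Below is a minimal usage sketch with the public TF 1.x API; the name demo_conv1d and the sample shapes are illustrative only and not part of the project:

import tensorflow as tf  # TF 1.x API assumed

def demo_conv1d(inputs, output_dim, name):
    """ Pointwise projection over the last axis, mirroring _conv1d above """
    with tf.variable_scope(name):
        input_dim = inputs.shape[-1].value
        weight = tf.get_variable('w', [input_dim, output_dim],
                                 initializer=tf.random_normal_initializer(stddev=0.02))
        bias = tf.get_variable('b', [output_dim], initializer=tf.zeros_initializer())
        flat = tf.reshape(inputs, [-1, input_dim])                     # [B, input_dim]
        out = tf.matmul(flat, weight) + bias                           # [B, output_dim]
        new_shape = tf.concat([tf.shape(inputs)[:-1], [output_dim]], axis=0)
        return tf.reshape(out, new_shape)                              # [..., output_dim]

# Usage sketch: project a [batch, seq_len, 16] tensor to 32 features.
x = tf.placeholder(tf.float32, [None, 10, 16])
y = demo_conv1d(x, 32, 'demo_proj')                                    # [None, 10, 32]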
Code Example #2
File: transformer.py Project: zhanpengfang/research
 def _linear(self, inputs, proj_dim, name):
     """ Computes a linear unit inside a full block - Projects to 'proj_dim' and back to 'input_dim'
         :param inputs: The inputs tensor -  [axis_0, ..., axis_n-1, input_dim]
         :param proj_dim: The dimension of the projection
         :param name: Name of the scope - To share weights between calls
         :return: A tensor of shape -  [axis_0, ..., axis_n-1, input_dim]
     """
     with variable_scope.variable_scope(name):
         input_dim = inputs.shape[-1].value
         output_h1 = gelu(self._conv1d(inputs, proj_dim, 'mlp_fc1'))
         output_h2 = self._conv1d(output_h1, input_dim, 'mlp_fc2')
         return output_h2
Code Example #3
    def __init__(self,
                 num_units,
                 memory,
                 memory_sequence_length=None,
                 normalize=False,
                 probability_fn=None,
                 score_mask_value=None,
                 dtype=None,
                 name_or_scope='BahdanauAttention'):
        """ Construct the Attention mechanism.
            :param num_units: The depth of the query mechanism.
            :param memory: The memory to query; usually the output of an RNN encoder. This tensor should be
                           shaped `[batch_size, max_time, ...]`.
            :param memory_sequence_length: (optional): Sequence lengths for the batch entries in memory. If provided,
                                           the memory tensor rows are masked with zeros for values past the respective
                                           sequence lengths.
            :param normalize: Python boolean. Whether to normalize the energy term.
            :param probability_fn: (optional) A `callable`. Converts the score to probabilities. The default is
                                   @{tf.nn.softmax}. Other options include @{tf.contrib.seq2seq.hardmax} and
                                   @{tf.contrib.sparsemax.sparsemax}.
                                   Its signature should be: `probabilities = probability_fn(score)`.
            :param score_mask_value: (optional): The mask value for score before passing into `probability_fn`.
                                     The default is -inf. Only used if `memory_sequence_length` is not None.
            :param dtype: The data type for the query and memory layers of the attention mechanism.
            :param name_or_scope: String or VariableScope to use when creating ops.
        """
        # pylint: disable=too-many-arguments
        if probability_fn is None:
            probability_fn = nn_ops.softmax
        if dtype is None:
            dtype = dtypes.float32
        wrapped_probability_fn = lambda score, _: probability_fn(score)

        self._num_units = num_units
        self._normalize = normalize
        self._name_or_scope = name_or_scope

        with variable_scope.variable_scope(name_or_scope,
                                           default_name='BahdanauAttention'):
            super(BahdanauAttention,
                  self).__init__(query_layer=core.Dense(num_units,
                                                        name='query_layer',
                                                        use_bias=False,
                                                        dtype=dtype),
                                 memory_layer=core.Dense(num_units,
                                                         name='memory_layer',
                                                         use_bias=False,
                                                         dtype=dtype),
                                 memory=memory,
                                 probability_fn=wrapped_probability_fn,
                                 memory_sequence_length=memory_sequence_length,
                                 score_mask_value=score_mask_value)
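A minimal construction sketch, assuming an encoder that produces [batch, max_time, depth] outputs (the placeholder shapes below are illustrative only):

import tensorflow as tf  # TF 1.x API assumed

encoder_outputs = tf.placeholder(tf.float32, [None, 50, 256])  # [batch, max_time, depth]
encoder_lengths = tf.placeholder(tf.int32, [None])             # [batch]

attention_mechanism = BahdanauAttention(num_units=128,
                                        memory=encoder_outputs,
                                        memory_sequence_length=encoder_lengths)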
Code Example #4
 def __call__(self, query, state):
     """ Score the query based on the keys and values.
         :param query: Tensor of dtype matching `self.values` and shape `[batch_size, query_depth]`.
         :param state: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]`
                       (`alignments_size` is memory's `max_time`).
         :return: Tensor of dtype matching `self.values` and shape `[batch_size, alignments_size]`
                  (`alignments_size` is memory's `max_time`).
     """
     with variable_scope.variable_scope(self._name_or_scope,
                                        'bahdanau_attention', [query]):
         processed_query = self.query_layer(
             query) if self.query_layer else query
         score = _bahdanau_score(processed_query, self._keys,
                                 self._normalize)
     alignments = self._probability_fn(score, state)
     next_state = alignments
     return alignments, next_state
Code Example #5
File: transformer.py Project: zhanpengfang/research
 def _block(self, inputs, past_attn, name):
     """ Computes a transformer block
         :param inputs: The inputs tensor - [batch, seq_len, emb_dim]
          :param past_attn: The past attention - [batch, 2, nb_heads, past_length, emb_size // nb_heads]
         :param name: Name of the scope - To share weights between calls
         :return: A tuple consisting of:
                 1) The output of the transformer block - [batch, seq_len, emb_dim]
                  2) The present attention - [batch, 2, nb_heads, seq_len, emb_size // nb_heads]
     """
     with variable_scope.variable_scope(name):
         input_dim = inputs.shape[-1].value
         h_out = inputs
         h_attn, present_attn = self._attn(self._norm(h_out, 'block_norm1'),
                                           input_dim, past_attn,
                                           'block_attn')
         h_out = h_out + h_attn
         h_mlp = self._linear(self._norm(h_out, 'block_norm2'),
                              input_dim * 4, 'block_mlp')
         h_out = h_out + h_mlp
         return h_out, present_attn
Code Example #6
File: transformer.py Project: zhanpengfang/research
 def _norm(inputs, name, axis=-1):
     """ Applies normalization to the input tensor by normalizing to mean=0, std_dev=1, then applying a gamma, beta
         :param inputs: The tensor inputs to normalize
         :param name: Name of the scope - To share weights between calls
         :param axis: Axis to normalize. Defaults to last one.
         :return: A tensor of the same shape as inputs, but normalized and transformed
     """
     with variable_scope.variable_scope(name):
         axis_dim = inputs.shape[axis].value
         gamma = variable_scope.get_variable(
             'gamma', [axis_dim],
             initializer=init_ops.constant_initializer(1.))
         beta = variable_scope.get_variable(
             'beta', [axis_dim],
             initializer=init_ops.constant_initializer(0.))
         mean = math_ops.reduce_mean(inputs, axis=axis, keepdims=True)
         var = math_ops.reduce_mean(gen_math_ops.square(inputs - mean),
                                    axis=axis,
                                    keepdims=True)
         norm_inputs = (inputs - mean) * gen_math_ops.rsqrt(var + 1e-8)
         outputs = gamma * norm_inputs + beta
         return outputs
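The normalization itself is ordinary layer normalization. A small NumPy sketch of the same formula (with gamma=1 and beta=0, their initial values) as a quick sanity check:

import numpy as np

x = np.random.randn(2, 5).astype(np.float32)
mean = x.mean(axis=-1, keepdims=True)
var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
norm = (x - mean) / np.sqrt(var + 1e-8)
print(norm.mean(axis=-1))  # ~0 per row
print(norm.std(axis=-1))   # ~1 per row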
Code Example #7
    def call(self, inputs, state):
        """ Performs a step of attention-wrapped RNN.

            1) Mix the `inputs` and previous step's `attention` output via `cell_input_fn`.
            2) Call the wrapped `cell` with this input and its previous state.
            3) Score the cell's output with `attention_mechanism`.
            4) Calculate the alignments by passing the score through the `normalizer`.
            5) Calculate the context vector as the inner product between the alignments and the attention_mechanism's
               values (memory).
            6) Calculate the attention output by concatenating the cell output and context through the attention
               layer (a linear layer with `attention_layer_size` outputs).

            :param inputs: (Possibly nested tuple of) Tensor, the input at this time step.
            :param state: An instance of `AttentionWrapperState` containing tensors from the previous time step.
            :return: A tuple `(attention_or_cell_output, next_state)`, where:
            - `attention_or_cell_output`: the attention value if `output_attention` is `True`, otherwise the cell output.
            - `next_state` is an instance of `AttentionWrapperState` containing the state calculated at this time step.
        """
        # pylint: disable=arguments-differ
        if not isinstance(state, AttentionWrapperState):
            raise TypeError(
                'Expected state to be an instance of AttentionWrapperState. Received %s instead. '
                % type(state))

        # Step 1: Calculate the true inputs to the cell based on the previous attention value.
        cell_inputs = self._cell_input_fn(inputs, state.attention)
        cell_state = state.cell_state
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        cell_batch_size = tensor_shape.dimension_value(
            cell_output.shape[0]) or array_ops.shape(cell_output)[0]
        error_message = (
            'When applying AttentionWrapper %s: ' % self.name +
            'Non-matching batch sizes between '
            'the memory (encoder output) and the query (decoder output). Are you using the '
            'BeamSearchDecoder? You may need to tile your memory input via the tf.contrib.seq2seq.'
            'tile_batch function with argument multiple=beam_width.')

        with variable_scope.variable_scope(self._name_or_scope,
                                           'AttentionWrapper',
                                           [inputs, state]):

            with ops.control_dependencies(
                    self._batch_size_checks(cell_batch_size, error_message)):
                cell_output = array_ops.identity(cell_output,
                                                 name='checked_cell_output')

            if self._is_multi:
                previous_attention_state = state.attention_state
                previous_alignment_history = state.alignment_history
            else:
                previous_attention_state = [state.attention_state]
                previous_alignment_history = [state.alignment_history]

            # Computing attention
            all_alignments = []
            all_attentions = []
            all_attention_states = []
            maybe_all_histories = []
            for i, attention_mechanism in enumerate(
                    self._attention_mechanisms):
                attention, alignments, next_attention_state = _compute_attention(
                    attention_mechanism, cell_output,
                    previous_attention_state[i], self._attention_layers[i]
                    if self._attention_layers else None)
                alignment_history = previous_alignment_history[i].write(
                    state.time, alignments) if self._alignment_history else ()

                all_attention_states.append(next_attention_state)
                all_alignments.append(alignments)
                all_attentions.append(attention)
                maybe_all_histories.append(alignment_history)

            # Building next state
            attention = array_ops.concat(all_attentions, 1)
            next_state = AttentionWrapperState(
                time=state.time + 1,
                cell_state=next_cell_state,
                attention=attention,
                attention_state=self._item_or_tuple(all_attention_states),
                alignments=self._item_or_tuple(all_alignments),
                alignment_history=self._item_or_tuple(maybe_all_histories))

            # Returning
            if self._output_attention:
                return attention, next_state
            return cell_output, next_state
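_compute_attention is not shown in this excerpt. A hedged sketch of what such a helper usually does for steps 3-6 of the docstring above (the project's actual helper may differ):

import tensorflow as tf  # TF 1.x API assumed

def compute_attention_sketch(attention_mechanism, cell_output, attention_state, attention_layer):
    """ Steps 3-6: score, normalize, build the context vector, project to attention """
    # Steps 3-4: the mechanism scores the query and returns normalized alignments.
    alignments, next_attention_state = attention_mechanism(cell_output, state=attention_state)
    # Step 5: context vector = alignments-weighted sum of the memory values.
    expanded = tf.expand_dims(alignments, 1)                                     # [batch, 1, max_time]
    context = tf.squeeze(tf.matmul(expanded, attention_mechanism.values), [1])   # [batch, depth]
    # Step 6: optionally project [cell_output; context] through the attention layer.
    if attention_layer is not None:
        attention = attention_layer(tf.concat([cell_output, context], 1))
    else:
        attention = context
    return attention, alignments, next_attention_state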
Code Example #8
    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name_or_scope='AttentionWrapper',
                 attention_layer=None):
        """ Construct the `AttentionWrapper`.

            :param cell: An instance of `RNNCell`.
            :param attention_mechanism: A list of `AttentionMechanism` instances or a single instance.
            :param attention_layer_size: A list of Python integers or a single Python integer,
                                         the depth of the attention (output) layer(s).
            :param alignment_history: Python boolean, whether to store alignment history from all time steps in the
                                      final output state
            :param cell_input_fn: (optional) A `callable`. The default is: concat([inputs, attention], axis=-1)
            :param output_attention: Python bool. If `True` (default), the output at each time step is the attn value.
            :param initial_cell_state: The initial state value to use for the cell when the user calls `zero_state()`.
            :param name_or_scope: String or VariableScope to use when creating ops.
            :param attention_layer: A list of `tf.layers.Layer` instances or a single `tf.layers.Layer` instance taking
                                    the context and cell output as inputs to generate attention at each time step.
                                    If None (default), use the context as attention at each time step.

            **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in `AttentionWrapper`,
                     then you must ensure that:

            - The encoder output has been tiled to `beam_width` via `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
            - The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to
                `true_batch_size * beam_width`.
            - The initial state created with `zero_state` above contains a `cell_state` value containing properly
                tiled final state from the encoder.
        """
        # pylint: disable=too-many-arguments
        self._name_or_scope = name_or_scope
        with variable_scope.variable_scope(name_or_scope, 'AttentionWrapper'):
            super(AttentionWrapper, self).__init__()
            rnn_cell_impl.assert_like_rnncell("cell", cell)

            # Attention mechanism
            if isinstance(attention_mechanism, (list, tuple)):
                self._is_multi = True
                attention_mechanisms = attention_mechanism
                for attn_mechanism in attention_mechanisms:
                    if not isinstance(attn_mechanism, AttentionMechanism):
                        raise TypeError(
                            'attention_mechanism must contain only instances of AttentionMechanism, saw '
                            'type: %s' % type(attn_mechanism).__name__)
            else:
                self._is_multi = False
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError(
                        'attention_mechanism must be an AttentionMechanism or list of multiple '
                        'AttentionMechanism instances, saw type: %s' %
                        type(attention_mechanism).__name__)
                attention_mechanisms = (attention_mechanism, )

            # Cell input function
            if cell_input_fn is None:
                cell_input_fn = lambda inputs, attention: array_ops.concat(
                    [inputs, attention], -1)
            else:
                if not callable(cell_input_fn):
                    raise TypeError(
                        'cell_input_fn must be callable, saw type: %s' %
                        type(cell_input_fn).__name__)

            # Attention layer size
            if attention_layer_size is not None and attention_layer is not None:
                raise ValueError(
                    'Only one of attention_layer_size and attention_layer should be set'
                )

            if attention_layer_size is not None:
                attention_layer_sizes = tuple(
                    attention_layer_size if isinstance(attention_layer_size, (
                        list, tuple)) else (attention_layer_size, ))
                if len(attention_layer_sizes) != len(attention_mechanisms):
                    raise ValueError(
                        'If provided, attention_layer_size must contain exactly one integer per '
                        'attention_mechanism, saw: %d vs %d' %
                        (len(attention_layer_sizes),
                         len(attention_mechanisms)))
                self._attention_layers = tuple(
                    core.Dense(attention_layer_size,
                               name='attention_layer',
                               use_bias=False,
                               dtype=attention_mechanisms[i].dtype) for i,
                    attention_layer_size in enumerate(attention_layer_sizes))
                self._attention_layer_size = sum(attention_layer_sizes)

            elif attention_layer is not None:
                self._attention_layers = tuple(attention_layer if isinstance(
                    attention_layer, (list, tuple)) else (attention_layer, ))
                if len(self._attention_layers) != len(attention_mechanisms):
                    raise ValueError(
                        'If provided, attention_layer must contain exactly one layer per '
                        'attention_mechanism, saw: %d vs %d' %
                        (len(self._attention_layers),
                         len(attention_mechanisms)))
                self._attention_layer_size = \
                    sum(tensor_shape.dimension_value(
                        layer.compute_output_shape([None,
                                                    cell.output_size
                                                    + tensor_shape.dimension_value(mechanism.values.shape[-1])])[-1])
                        for layer, mechanism in zip(self._attention_layers, attention_mechanisms))
            else:
                self._attention_layers = None
                self._attention_layer_size = sum(
                    tensor_shape.dimension_value(
                        attention_mechanism.values.shape[-1])
                    for attention_mechanism in attention_mechanisms)

            self._cell = cell
            self._attention_mechanisms = attention_mechanisms
            self._cell_input_fn = cell_input_fn
            self._output_attention = output_attention
            self._alignment_history = alignment_history

            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (tensor_shape.dimension_value(
                    final_state_tensor.shape[0])
                                    or array_ops.shape(final_state_tensor)[0])
                error_message = (
                    'When constructing AttentionWrapper %s: ' % self._base_name
                    +
                    'Non-matching batch sizes between the memory (encoder output) and initial_cell_state. '
                    'Are you using the BeamSearchDecoder? You may need to tile your initial state via the '
                    'tf.contrib.seq2seq.tile_batch function with argument multiple=beam_width.'
                )

                with ops.control_dependencies(
                        self._batch_size_checks(state_batch_size,
                                                error_message)):
                    self._initial_cell_state = \
                        nest.map_structure(lambda state: array_ops.identity(state, name='check_initial_cell_state'),
                                           initial_cell_state)
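A hedged sketch of the BeamSearchDecoder setup described in the NOTE above, assuming the standard tf.contrib.seq2seq tile_batch helper and an AttentionWrapperState with a clone method; all shapes and names below are illustrative stand-ins, not the project's code:

import tensorflow as tf  # TF 1.x API assumed
from tensorflow.contrib import seq2seq

true_batch_size = 32
beam_width = 5
decoder_cell = tf.nn.rnn_cell.LSTMCell(256)
encoder_outputs = tf.placeholder(tf.float32, [true_batch_size, 50, 256])   # [batch, max_time, depth]
encoder_lengths = tf.placeholder(tf.int32, [true_batch_size])
encoder_final_state = decoder_cell.zero_state(true_batch_size, tf.float32)  # stand-in for the encoder state

# Tile memory, lengths, and final state to beam_width (NOT tf.tile).
tiled_memory = seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
tiled_lengths = seq2seq.tile_batch(encoder_lengths, multiplier=beam_width)
tiled_final_state = seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)

attention_mechanism = BahdanauAttention(num_units=128,
                                        memory=tiled_memory,
                                        memory_sequence_length=tiled_lengths)
attn_cell = AttentionWrapper(decoder_cell, attention_mechanism, attention_layer_size=128)
initial_state = attn_cell.zero_state(batch_size=true_batch_size * beam_width, dtype=tf.float32)
initial_state = initial_state.clone(cell_state=tiled_final_state)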
Code Example #9
File: transformer.py Project: zhanpengfang/research
    def _attn(self, inputs, attn_dim, past_attn, name):
        """ Performs multi-head attention inside a transformer block
            :param inputs: The tensor inputs - [batch, seq_len, emb_size]
            :param attn_dim: The dimension of the attention (and output)
            :param past_attn: The past attention - [batch, 2, nb_heads, seq_len, emb_size // nb_heads]
            :param name: Name of the scope - To share weights between calls
            :return: A tuple consisting of:
                    1) The output of the attention - [batch, seq_len, attn_dim]
                    2) The present attention - [batch, 2, nb_heads, seq_len, emb_size // nb_heads]
        """
        assert inputs.shape.ndims == 3, 'Expected [batch, seq_len, emb_size]'
        with variable_scope.variable_scope(name):

            # Computing the query, key, and value vectors
            query_keys_values = self._conv1d(
                inputs, 3 * attn_dim,
                'attn_fc1')  # [batch, seq_len, 3 * attn_dim]
            query, keys, values = array_ops.split(
                query_keys_values, 3, axis=-1)  # 3x [batch, seq_len, attn_dim]

            # Splitting into nb_heads of size attn_dim // nb_heads
            # Output format is [batch, nb_heads, seq_len, attn_dim // nb_heads]
            query = self._split_in_heads(
                query)  # [bz, nb_heads, seq_len, head_sz]
            keys = self._split_in_heads(
                keys)  # [bz, nb_heads, seq_len, head_sz]
            values = self._split_in_heads(
                values)  # [bz, nb_heads, seq_len, head_sz]
            head_size = query.shape[-1].value

            # Stacking keys and values to get the present_attn
            present_attn = array_ops.stack(
                [keys, values], axis=1)  # [bz, 2, nb_heads, seq_len, head_sz]

            # Adding past_attn to keys and values
            past_keys, past_values = array_ops.unstack(
                past_attn, 2, axis=1)  # 2x [bz, nb_heads, past_len, head_sz]
            keys = array_ops.concat(
                [past_keys, keys],
                axis=-2)  # [bz, nb_heads, total_len, head_sz]
            values = array_ops.concat(
                [past_values, values],
                axis=-2)  # [bz, nb_heads, total_len, head_sz]

            # Performing multi-head attention
            attn_w = math_ops.matmul(
                query, keys,
                transpose_b=True)  # [bz, nb_heads, seq_len, total_len]
            attn_w = attn_w * gen_math_ops.rsqrt(
                math_ops.cast(head_size, attn_w.dtype) + 1e-8)
            attn_mask = self._mask_attn_weights(
                attn_w)  # [bz, 1, seq_len, total_len]
            attn_w = attn_w * attn_mask - math_ops.cast(
                1e10, attn_w.dtype) * (1. - attn_mask)
            attn_w = nn_ops.softmax(attn_w)
            attn = math_ops.matmul(attn_w,
                                   values)  # [bz, nb_heads, seq_len, head_sz]

            # Merging attention heads, then 1d conv before returning
            out_attn = self._merge_heads(attn)  # [bz, seq_len, attn_dim]
            out_attn = self._conv1d(out_attn, attn_dim,
                                    'attn_fc2')  # [bz, seq_len, attn_dim]

            # Returning
            return out_attn, present_attn
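_mask_attn_weights is not shown in this excerpt; a typical GPT-2 style causal mask for weights of shape [.., seq_len, total_len] could be built as follows (a sketch under that assumption, not the project's implementation):

import tensorflow as tf  # TF 1.x API assumed

def causal_mask_sketch(seq_len, total_len, dtype):
    """ 1.0 where query position i may attend to key position j, 0.0 elsewhere """
    i = tf.range(seq_len)[:, None]                 # query offsets within the current step
    j = tf.range(total_len)                        # key positions (past + current)
    mask = i >= j - total_len + seq_len            # query i sees keys 0 .. past_len + i
    return tf.cast(mask, dtype)[None, None, :, :]  # [1, 1, seq_len, total_len]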
Code Example #10
File: transformer.py Project: zhanpengfang/research
    def _step(self, inputs, past_attns, time, feeder_cell, feeder_state):
        """ Performs the block operation on n-layers
            :param inputs: The tensor inputs (embedding of each word) - [batch, seq_len, emb_size]
            :param past_attns: The past attentions - [batch, nb_layers, 2, nb_heads, past_length, emb_size // nb_heads]
            :param time: A tensor representing the current time step
            :param feeder_cell: None or A feeder cell that returns a RNN cell output to use for conditioning
            :param feeder_state: None or the initial state of the feeder cell
            :return: A tuple consisting of:
                        1) The cell outputs - [batch, seq_len, emb_size]
                        2) The present attention - [batch, nb_layers, 2, nb_heads, seq_len, emb_size // nb_heads]
                        3) The new state of the feeder cell
        """
        with variable_scope.variable_scope(self._scope, default_name='step'):
            past_length = array_ops.shape(past_attns)[
                -2]  # How many past attention steps we have
            seq_len = array_ops.shape(inputs)[
                -2]  # How many steps are we computing for the current time
            emb_size = inputs.shape[-1].value  # The size of the embedding
            assert emb_size == self._emb_size, 'Expected an embedding size of %d' % self._emb_size

            # 1) Computing the word embedding of each token
            assert inputs.shape.ndims == 3, 'Expected [batch, seq_len, emb_size]'  # [bz, seq, emb]
            out_h = inputs

            # 2) Computing the position embedding of each token
            # If we know the context was padded, the effective past length is the context length + nb of time steps
            if self._past_seq_lengths is not None:
                past_length = gen_math_ops.minimum(
                    past_length,
                    self._past_seq_lengths + time)[:, None]  # [bz, 1]
            else:
                past_length = gen_array_ops.fill([self._batch_size, 1],
                                                 value=past_length)  # [bz, 1]
            step_ix = math_ops.range(seq_len)[None, :]  # [1, seq_len]
            token_positions = gen_math_ops.add(past_length,
                                               step_ix)  # [batch, seq_len]
            token_positions = gen_math_ops.minimum(
                self._position_emb_size - 1,
                token_positions)  # [batch, seq_len]
            h_pos = self._position_embedding_fn(
                token_positions)  # [bz, seq, emb]
            out_h = out_h + h_pos

            # 3) If we have a feeder cell, we also need to condition 'h' on it.
            next_feeder_state = feeder_state
            if feeder_cell is not None:
                assert feeder_state is not None, 'A feeder state is required if a feeder cell is provided.'
                assert inputs.shape[
                    1].value == 1, 'The seq dimension must be 1 to use a feeder_cell'
                feeder_outputs, next_feeder_state = feeder_cell(
                    array_ops.squeeze(inputs, axis=1), feeder_state)
                h_feed = feeder_outputs  # [bz, feeder_sz]
                if feeder_outputs.shape[-1].value != emb_size:
                    h_feed = core.Dense(emb_size,
                                        activation=None,
                                        name='h_feed')(h_feed)  # [bz, emb]
                h_feed = gen_array_ops.tile(h_feed[:, None, :],
                                            [1, seq_len, 1])  # [bz, seq, emb]
                out_h = out_h + h_feed

            # Transformer
            presents = []
            pasts = array_ops.unstack(
                past_attns,
                axis=1)  # list of [batch, 2, heads, past_len, head_sz]
            assert len(
                pasts
            ) == self._nb_layers, 'Expected the past attention to have %d layers.' % self._nb_layers

            for layer_ix, past_attn in enumerate(pasts):
                out_h, present = self._block(out_h, past_attn,
                                             'layer.%d' % layer_ix)
                presents += [present]
            presents = array_ops.stack(presents, axis=1)

            # Normalizing and returning
            cell_outputs = self._norm(out_h, 'norm_h')  # [batch, seq, emb]
            return cell_outputs, presents, next_feeder_state
Code Example #11
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   invariants_map=None,
                   swap_memory=False,
                   scope=None):
    """ Performs dynamic decoding with `decoder`.
        :param decoder: A `Decoder` instance.
        :param output_time_major: If True, outputs [time, batch, ...], otherwise outputs [batch, time, ...]
        :param impute_finished: If true, finished states are copied through to the end of decoding
        :param maximum_iterations: Int or None. The maximum number of steps (otherwise decode until it's done)
        :param parallel_iterations: Argument passed to tf.while_loop
        :param invariants_map: Optional. Dictionary of tensor path (in initial_state) to its shape invariant.
        :param swap_memory: Argument passed to `tf.while_loop`.
        :param scope: Optional variable scope to use.
        :return: A tuple of 1) final_outputs, 2) final_state, 3) final_sequence_length
    """
    if not isinstance(decoder, seq2seq.Decoder):
        raise TypeError('Expected decoder to be type Decoder, but saw: %s' %
                        type(decoder))

    with variable_scope.variable_scope(scope, 'decoder') as varscope:

        # Determine context types.
        ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
        is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
        in_while_loop = control_flow_util.GetContainingWhileContext(
            ctxt) is not None

        # Properly cache variable values inside the while_loop.
        # Don't set a caching device when running in a loop, since it is possible that train steps could be wrapped
        # in a tf.while_loop. In that scenario caching prevents forward computations in loop iterations from re-reading
        # the updated weights.
        if not context.executing_eagerly() and not in_while_loop:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        # Setting maximum iterations
        if maximum_iterations is not None:
            maximum_iterations = ops.convert_to_tensor(
                maximum_iterations,
                dtype=dtypes.int32,
                name="maximum_iterations")
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError('maximum_iterations must be a scalar')

        def _inv_shape(maybe_ta):
            """ Returns the invariatns shape """
            if isinstance(maybe_ta, tensor_array_ops.TensorArray):
                return maybe_ta.flow.shape
            return maybe_ta.shape

        def _invariants(structure):
            """ Returns the invariants of a structure """
            return nest.map_structure(_inv_shape, structure)

        def _map_invariants(structure):
            """ Returns the invariants of a structure, but replaces the invariant using the value in invariants_map """
            return nest.map_structure_with_paths(
                lambda path, tensor:
                (invariants_map or {}).get(path, _inv_shape(tensor)),
                structure)

        # Initializing decoder
        initial_finished, initial_inputs, initial_state = decoder.initialize()
        zero_outputs = _create_zero_outputs(decoder.output_size,
                                            decoder.output_dtype,
                                            decoder.batch_size)

        if is_xla and maximum_iterations is None:
            raise ValueError(
                'maximum_iterations is required for XLA compilation.')
        if maximum_iterations is not None:
            initial_finished = gen_math_ops.logical_or(initial_finished,
                                                       maximum_iterations <= 0)
        initial_sequence_lengths = array_ops.zeros_like(initial_finished,
                                                        dtype=dtypes.int32)
        initial_time = constant_op.constant(0, dtype=dtypes.int32)

        # Creating initial output TA
        def _shape(batch_size, from_shape):
            """ Returns the batch_size concatenated with the from_shape """
            if (not isinstance(from_shape, tensor_shape.TensorShape)
                    or from_shape.ndims == 0):
                return tensor_shape.TensorShape(None)
            batch_size = tensor_util.constant_value(
                ops.convert_to_tensor(batch_size, name='batch_size'))
            return tensor_shape.TensorShape([batch_size
                                             ]).concatenate(from_shape)

        dynamic_size = maximum_iterations is None or not is_xla

        def _create_ta(shape, dtype):
            """ Creates a tensor array"""
            return tensor_array_ops.TensorArray(
                dtype=dtype,
                size=0 if dynamic_size else maximum_iterations,
                dynamic_size=dynamic_size,
                element_shape=_shape(decoder.batch_size, shape))

        initial_outputs_ta = nest.map_structure(_create_ta,
                                                decoder.output_size,
                                                decoder.output_dtype)

        def condition(unused_time, unused_outputs_ta, unused_state,
                      unused_inputs, finished, unused_sequence_lengths):
            """ While loop condition"""
            return gen_math_ops.logical_not(math_ops.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """ Internal while_loop body. """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = gen_math_ops.logical_or(
                    decoder_finished, finished)
            next_sequence_lengths = array_ops.where(
                gen_math_ops.logical_not(finished),
                gen_array_ops.fill(array_ops.shape(sequence_lengths),
                                   time + 1), sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays, multiple dynamic dims, and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                elif None in new.shape.as_list()[1:]:
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)

        res = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=(initial_time, initial_outputs_ta, initial_state,
                       initial_inputs, initial_finished,
                       initial_sequence_lengths),
            shape_invariants=(_invariants(initial_time),
                              _invariants(initial_outputs_ta),
                              _map_invariants(initial_state),
                              _invariants(initial_inputs),
                              _invariants(initial_finished),
                              _invariants(initial_sequence_lengths)),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory)

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = nest.map_structure(lambda ta: ta.stack(),
                                           final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:
            pass

        if not output_time_major:
            final_outputs = nest.map_structure(_transpose_batch_time,
                                               final_outputs)

    return final_outputs, final_state, final_sequence_lengths
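A hedged usage sketch, assuming a standard tf.contrib.seq2seq helper and decoder around the attention cell built in the sketch after Code Example #8 (attn_cell and initial_state); the decoder input names are illustrative only:

import tensorflow as tf  # TF 1.x API assumed
from tensorflow.contrib import seq2seq

decoder_embedded_inputs = tf.placeholder(tf.float32, [None, 20, 256])  # [batch, dec_len, emb]
decoder_lengths = tf.placeholder(tf.int32, [None])                     # [batch]

helper = seq2seq.TrainingHelper(decoder_embedded_inputs, decoder_lengths)
basic_decoder = seq2seq.BasicDecoder(attn_cell, helper, initial_state=initial_state)
outputs, final_state, final_lengths = dynamic_decode(basic_decoder,
                                                     maximum_iterations=50,
                                                     impute_finished=True,
                                                     swap_memory=True)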