def test_infer_greedy_with_context_without_memory(self):
        """Tests "infer_greedy" decoding that is given a context while no
        encoder memory is supplied.
        """
        transformer = TransformerDecoder(
            vocab_size=self._vocab_size, output_layer=self._output_layer)
        greedy_helper = tx_helper.GreedyEmbeddingHelper(
            self._embedding_fn, self._start_tokens, self._end_token)

        # Every memory-related argument stays `None`; only the context is
        # handed to the decoder.
        decode_args = dict(
            memory=None,
            memory_sequence_length=None,
            memory_attention_bias=None,
            inputs=None,
            decoding_strategy='infer_greedy',
            helper=greedy_helper,
            context=self._context,
            context_sequence_length=self._context_length,
            end_token=self._end_token,
            max_decoding_length=self._max_decode_len,
            mode=tf.estimator.ModeKeys.PREDICT)
        # The second returned element (the lengths) is not checked here.
        decoded, _ = transformer(**decode_args)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            fetched = sess.run(decoded)
            self.assertIsInstance(fetched, TransformerDecoderOutput)
 def test_greedy_embedding_helper(self):
     """Tests decoding in PREDICT mode driven by an explicitly constructed
     `tx_helper.GreedyEmbeddingHelper`, attending to the encoder memory.
     """
     transformer = TransformerDecoder(
         vocab_size=self._vocab_size, output_layer=self._output_layer)
     greedy_helper = tx_helper.GreedyEmbeddingHelper(
         self._embedding, self._start_tokens, self._end_token)

     # No attention bias is passed; the memory sequence lengths are given
     # instead. The second returned element (the lengths) is not checked.
     decoded, _ = transformer(
         memory=self._memory,
         memory_sequence_length=self._memory_sequence_length,
         memory_attention_bias=None,
         helper=greedy_helper,
         max_decoding_length=self._max_decode_len,
         mode=tf.estimator.ModeKeys.PREDICT)

     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         fetched = sess.run(decoded)
         self.assertIsInstance(fetched, TransformerDecoderOutput)
Example #3
0
    def _build(self,
               decoding_strategy="train_greedy",
               initial_state=None,
               inputs=None,
               sequence_length=None,
               embedding=None,
               start_tokens=None,
               end_token=None,
               softmax_temperature=None,
               max_decoding_length=None,
               impute_finished=False,
               output_time_major=False,
               input_time_major=False,
               helper=None,
               mode=None,
               **kwargs):
        """Performs decoding. This is a shared interface for both
        :class:`~texar.modules.BasicRNNDecoder` and
        :class:`~texar.modules.AttentionRNNDecoder`.

        The function provides **3 ways** to specify the
        decoding method, with varying flexibility:

        1. The :attr:`decoding_strategy` argument: A string taking value of:

            - **"train_greedy"**: decoding in teacher-forcing fashion \
              (i.e., feeding \
              `ground truth` to decode the next step), and each sample is \
              obtained by taking the `argmax` of the RNN output logits. \
              Arguments :attr:`(inputs, sequence_length, input_time_major)` \
              are required for this strategy, and argument :attr:`embedding` \
              is optional.
            - **"infer_greedy"**: decoding in inference fashion (i.e., feeding \
              the `generated` sample to decode the next step), and each sample\
              is obtained by taking the `argmax` of the RNN output logits.\
              Arguments :attr:`(embedding, start_tokens, end_token)` are \
              required for this strategy, and argument \
              :attr:`max_decoding_length` is optional.
            - **"infer_sample"**: decoding in inference fashion, and each
              sample is obtained by `random sampling` from the RNN output
              distribution. Arguments \
              :attr:`(embedding, start_tokens, end_token)` are \
              required for this strategy, and argument \
              :attr:`max_decoding_length` is optional.

          This argument is used only when argument :attr:`helper` is `None`.

          Example:

            .. code-block:: python

                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

                # Teacher-forcing decoding
                outputs_1, _, _ = decoder(
                    decoding_strategy='train_greedy',
                    inputs=embedder(data_batch['text_ids']),
                    sequence_length=data_batch['length']-1)

                # Random sample decoding. Gets 100 sequence samples
                outputs_2, _, sequence_length = decoder(
                    decoding_strategy='infer_sample',
                    start_tokens=[data.vocab.bos_token_id]*100,
                    end_token=data.vocab.eos.token_id,
                    embedding=embedder,
                    max_decoding_length=60)

        2. The :attr:`helper` argument: An instance of subclass of \
           :class:`texar.modules.Helper`. This
           provides a superset of decoding strategies than above, for example:

            - :class:`~texar.modules.TrainingHelper` corresponding to the \
              "train_greedy" strategy.
            - :class:`~texar.modules.GreedyEmbeddingHelper` and \
              :class:`~texar.modules.SampleEmbeddingHelper` corresponding to \
              the "infer_greedy" and "infer_sample", respectively.
            - :class:`~texar.modules.TopKSampleEmbeddingHelper` for Top-K \
              sample decoding.
            - :class:`ScheduledEmbeddingTrainingHelper` and \
              :class:`ScheduledOutputTrainingHelper` for scheduled \
              sampling.
            - :class:`~texar.modules.SoftmaxEmbeddingHelper` and \
              :class:`~texar.modules.GumbelSoftmaxEmbeddingHelper` for \
              soft decoding and gradient backpropagation.

          Helpers give the maximal flexibility of configuring the decoding\
          strategy.

          Example:

            .. code-block:: python

                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

                # Teacher-forcing decoding, same as above with
                # `decoding_strategy='train_greedy'`
                helper_1 = texar.modules.TrainingHelper(
                    inputs=embedders(data_batch['text_ids']),
                    sequence_length=data_batch['length']-1)
                outputs_1, _, _ = decoder(helper=helper_1)

                # Gumbel-softmax decoding
                helper_2 = GumbelSoftmaxEmbeddingHelper(
                    embedding=embedder,
                    start_tokens=[data.vocab.bos_token_id]*100,
                    end_token=data.vocab.eos_token_id,
                    tau=0.1)
                outputs_2, _, sequence_length = decoder(
                    max_decoding_length=60, helper=helper_2)

        3. :attr:`hparams["helper_train"]` and :attr:`hparams["helper_infer"]`:\
           Specifying the helper through hyperparameters. Train and infer \
           strategy is toggled based on :attr:`mode`. Appriopriate arguments \
           (e.g., :attr:`inputs`, :attr:`start_tokens`, etc) are selected to \
           construct the helper. Additional arguments for helper constructor \
           can be provided either through :attr:`**kwargs`, or through \
           :attr:`hparams["helper_train/infer"]["kwargs"]`.

           This means is used only when both :attr:`decoding_strategy` and \
           :attr:`helper` are `None`.

           Example:

             .. code-block:: python

                h = {
                    "helper_infer": {
                        "type": "GumbelSoftmaxEmbeddingHelper",
                        "kwargs": { "tau": 0.1 }
                    }
                }
                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size, hparams=h)

                # Gumbel-softmax decoding
                output, _, _ = decoder(
                    decoding_strategy=None, # Sets to None explicit
                    embedding=embedder,
                    start_tokens=[data.vocab.bos_token_id]*100,
                    end_token=data.vocab.eos_token_id,
                    max_decoding_length=60,
                    mode=tf.estimator.ModeKeys.PREDICT)
                        # PREDICT mode also shuts down dropout

        Args:
            decoding_strategy (str): A string specifying the decoding
                strategy. Different arguments are required based on the
                strategy.
                Ignored if :attr:`helper` is given.
            initial_state (optional): Initial state of decoding.
                If `None` (default), zero state is used.

            inputs (optional): Input tensors for teacher forcing decoding.
                Used when `decoding_strategy` is set to "train_greedy", or
                when `hparams`-configured helper is used.

                - If :attr:`embedding` is `None`, `inputs` is directly \
                fed to the decoder. E.g., in `"train_greedy"` strategy, \
                `inputs` must be a 3D Tensor of shape \
                `[batch_size, max_time, emb_dim]` (or \
                `[max_time, batch_size, emb_dim]` if `input_time_major`==True).
                - If `embedding` is given, `inputs` is used as index \
                to look up embeddings and feed in the decoder. \
                E.g., if `embedding` is an instance of \
                :class:`~texar.modules.WordEmbedder`, \
                then :attr:`inputs` is usually a 2D int Tensor \
                `[batch_size, max_time]` (or \
                `[max_time, batch_size]` if `input_time_major`==True) \
                containing the token indexes.
            sequence_length (optional): A 1D int Tensor containing the
                sequence length of :attr:`inputs`.
                Used when `decoding_strategy="train_greedy"` or
                `hparams`-configured helper is used.
            embedding (optional): Embedding used when:

                - "infer_greedy" or "infer_sample" `decoding_strategy` is \
                used. This can be a callable or the `params` argument for \
                :tf_main:`embedding_lookup <nn/embedding_lookup>`. \
                If a callable, it can take a vector tensor of token `ids`, \
                or take two arguments (`ids`, `times`), where `ids` \
                is a vector tensor of token ids, and `times` is a vector tensor\
                of time steps (i.e., position ids). The latter case can be used\
                when attr:`embedding` is a combination of word embedding and\
                position embedding. `embedding` is required in this case.
                - "train_greedy" `decoding_strategy` is used.\
                This can be a callable or the `params` argument for \
                :tf_main:`embedding_lookup <nn/embedding_lookup>`. \
                If a callable, it can take :attr:`inputs` and returns \
                the input embedding. `embedding` is optional in this case.
            start_tokens (optional): A int Tensor of shape `[batch_size]`,
                the start tokens. Used when `decoding_strategy="infer_greedy"`
                or `"infer_sample"`, or when the helper specified in `hparams`
                is used.

                Example:

                    .. code-block:: python

                        data = tx.data.MonoTextData(hparams)
                        iterator = DataIterator(data)
                        batch = iterator.get_next()

                        bos_token_id = data.vocab.bos_token_id
                        start_tokens=tf.ones_like(batch['length'])*bos_token_id

            end_token (optional): A int 0D Tensor, the token that marks end
                of decoding.
                Used when `decoding_strategy="infer_greedy"` or
                `"infer_sample"`, or when the helper specified in `hparams`
                is used.
            softmax_temperature (optional): A float 0D Tensor, value to divide
                the logits by before computing the softmax. Larger values
                (above 1.0) result in more random samples. Must > 0. If `None`,
                1.0 is used.
                Used when `decoding_strategy="infer_sample"`.
            max_decoding_length: A int scalar Tensor indicating the maximum
                allowed number of decoding steps. If `None` (default), either
                `hparams["max_decoding_length_train"]` or
                `hparams["max_decoding_length_infer"]` is used
                according to :attr:`mode`.
            impute_finished (bool): If `True`, then states for batch
                entries which are marked as finished get copied through and
                the corresponding outputs get zeroed out.  This causes some
                slowdown at each time step, but ensures that the final state
                and outputs have the correct values and that backprop ignores
                time steps that were marked as finished.
            output_time_major (bool): If `True`, outputs are returned as
                time major tensors. If `False` (default), outputs are returned
                as batch major tensors.
            input_time_major (optional): Whether the :attr:`inputs` tensor is
                time major.
                Used when `decoding_strategy="train_greedy"` or
                `hparams`-configured helper is used.
            helper (optional): An instance of
                :class:`texar.modules.Helper`
                that defines the decoding strategy. If given,
                `decoding_strategy`
                and helper configs in :attr:`hparams` are ignored.
            mode (str, optional): A string taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`. If
                `TRAIN`, training related hyperparameters are used (e.g.,
                `hparams['max_decoding_length_train']`), otherwise,
                inference related hyperparameters are used (e.g.,
                `hparams['max_decoding_length_infer']`).
                If `None` (default), `TRAIN` mode is used.
            **kwargs: Other keyword arguments for constructing helpers
                defined by `hparams["helper_trainn"]` or
                `hparams["helper_infer"]`.

        Returns:
            `(outputs, final_state, sequence_lengths)`, where

            - **`outputs`**: an object containing the decoder output on all \
            time steps.
            - **`final_state`**: is the cell state of the final time step.
            - **`sequence_lengths`**: is an int Tensor of shape `[batch_size]` \
            containing the length of each sample.
        """
        # Helper
        if helper is not None:
            pass
        elif decoding_strategy is not None:
            if decoding_strategy == "train_greedy":
                helper = rnn_decoder_helpers._get_training_helper(
                    inputs, sequence_length, embedding, input_time_major)
            elif decoding_strategy == "infer_greedy":
                helper = tx_helper.GreedyEmbeddingHelper(
                    embedding, start_tokens, end_token)
            elif decoding_strategy == "infer_sample":
                helper = tx_helper.SampleEmbeddingHelper(
                    embedding, start_tokens, end_token, softmax_temperature)
            else:
                raise ValueError(
                    "Unknown decoding strategy: {}".format(decoding_strategy))
        else:
            if is_train_mode_py(mode):
                kwargs_ = copy.copy(self._hparams.helper_train.kwargs.todict())
                helper_type = self._hparams.helper_train.type
            else:
                kwargs_ = copy.copy(self._hparams.helper_infer.kwargs.todict())
                helper_type = self._hparams.helper_infer.type
            kwargs_.update({
                "inputs": inputs,
                "sequence_length": sequence_length,
                "time_major": input_time_major,
                "embedding": embedding,
                "start_tokens": start_tokens,
                "end_token": end_token,
                "softmax_temperature": softmax_temperature
            })
            kwargs_.update(kwargs)
            helper = rnn_decoder_helpers.get_helper(helper_type, **kwargs_)
        self._helper = helper

        # Initial state
        if initial_state is not None:
            self._initial_state = initial_state
        else:
            self._initial_state = self.zero_state(batch_size=self.batch_size,
                                                  dtype=tf.float32)

        # Maximum decoding length
        max_l = max_decoding_length
        if max_l is None:
            max_l_train = self._hparams.max_decoding_length_train
            if max_l_train is None:
                max_l_train = utils.MAX_SEQ_LENGTH
            max_l_infer = self._hparams.max_decoding_length_infer
            if max_l_infer is None:
                max_l_infer = utils.MAX_SEQ_LENGTH
            max_l = tf.cond(is_train_mode(mode), lambda: max_l_train,
                            lambda: max_l_infer)
        self.max_decoding_length = max_l
        # Decode
        outputs, final_state, sequence_lengths = dynamic_decode(
            decoder=self,
            impute_finished=impute_finished,
            maximum_iterations=max_l,
            output_time_major=output_time_major)

        if not self._built:
            self._add_internal_trainable_variables()
            # Add trainable variables of `self._cell` which may be
            # constructed externally.
            self._add_trainable_variable(
                layers.get_rnn_cell_trainable_variables(self._cell))
            if isinstance(self._output_layer, tf.layers.Layer):
                self._add_trainable_variable(
                    self._output_layer.trainable_variables)
            # Add trainable variables of `self._beam_search_rnn_cell` which
            # may already be constructed and used.
            if self._beam_search_cell is not None:
                self._add_trainable_variable(
                    self._beam_search_cell.trainable_variables)

            self._built = True

        return outputs, final_state, sequence_lengths
Example #4
0
    def _build(
            self,  # pylint: disable=arguments-differ, too-many-statements
            decoding_strategy='train_greedy',
            inputs=None,
            memory=None,
            memory_sequence_length=None,
            memory_attention_bias=None,
            beam_width=None,
            length_penalty=0.,
            start_tokens=None,
            end_token=None,
            context=None,
            context_sequence_length=None,
            softmax_temperature=None,
            max_decoding_length=None,
            impute_finished=False,
            embedding=None,
            helper=None,
            mode=None):
        """Performs decoding.

        The interface is mostly the same with that of RNN decoders
        (see :meth:`~texar.modules.RNNDecoderBase._build`). The main difference
        is that, here, `sequence_length` is not needed, and continuation
        generation is additionally supported.

        The function provides **3 ways** to specify the decoding method, with
        varying flexibility:

        1. The :attr:`decoding_strategy` argument.

            - **"train_greedy"**: decoding in teacher-forcing fashion (i.e.,
              feeding ground truth to decode the next step), and for each step
              sample is obtained by taking the `argmax` of logits.
              Argument :attr:`inputs` is required for this strategy.
            - **"infer_greedy"**: decoding in inference fashion (i.e., feeding
              `generated` sample to decode the next step), and for each step
              sample is obtained by taking the `argmax` of logits.
              Arguments :attr:`(start_tokens, end_token)` are
              required for this strategy, and argument
              :attr:`max_decoding_length` is optional.
            - **"infer_sample"**: decoding in inference fashion, and for each
              step sample is obtained by `random sampling` from the logits.
              Arguments :attr:`(start_tokens, end_token)` are required for this
              strategy, and argument :attr:`max_decoding_length` is optional.

          This argument is used only when arguments :attr:`helper` and
          :attr:`beam_width` are both `None`.

        2. The :attr:`helper` argument: An instance of subclass of
           :class:`texar.modules.Helper`.
           This provides a superset of decoding strategies than above.
           The interface is the same as in RNN decoders.
           Please refer to :meth:`texar.modules.RNNDecoderBase._build` for
           detailed usage and examples.

           Note that, here, though using a
           :class:`~texar.modules.TrainingHelper` corresponds to the
           "train_greedy" strategy above and will get the same output results,
           the implementation is *slower* than
           directly setting `decoding_strategy="train_greedy"`.

           Argument :attr:`max_decoding_length` is optional.

        3. **Beam search**: set :attr:`beam_width` to use beam search decoding.
           Arguments :attr:`(start_tokens, end_token)` are required,
           and argument :attr:`max_decoding_length` is optional.

        Args:
            memory (optional): The memory to attend, e.g., the output of an RNN
                encoder. A Tensor of shape `[batch_size, memory_max_time, dim]`.
            memory_sequence_length (optional): A Tensor of shape `[batch_size]`
                containing the sequence lengths for the batch entries in
                memory. Used to create attention bias of
                :attr:`memory_attention_bias` is not given. Ignored if
                `memory_attention_bias` is provided.
            memory_attention_bias (optional): A Tensor of shape
                `[batch_size, num_heads, memory_max_time, dim]`.
                An attention bias typically sets the value of a padding
                position to a large negative value for masking. If not given,
                :attr:`memory_sequence_length` is used to automatically
                create an attention bias.
            inputs (optional): Input tensor for teacher forcing decoding, of
                shape `[batch_size, target_max_time, emb_dim]` containing the
                target sequence word embeddings.
                Used when :attr:`decoding_strategy` is set to "train_greedy".
            decoding_strategy (str): A string specifying the decoding
                strategy, including "train_greedy", "infer_greedy",
                "infer_sample".
                Different arguments are required based on the
                strategy. See above for details. Ignored if
                :attr:`beam_width` or :attr:`helper` is set.
            beam_width (int): Set to use beam search. If given,
                :attr:`decoding_strategy` is ignored.
            length_penalty (float): Length penalty coefficient used in beam
                search decoding. Refer to https://arxiv.org/abs/1609.08144
                for more details.
                It Should be larger if longer sentences are wanted.
            start_tokens (optional): An int Tensor of shape `[batch_size]`,
                containing the start tokens.
                Used when :attr:`decoding_strategy` = "infer_greedy" or
                "infer_sample", or :attr:`beam_width` is set.
                Ignored when context is set.
            end_token (optional): An int 0D Tensor, the token that marks end
                of decoding.
                Used when :attr:`decoding_strategy` = "infer_greedy" or
                "infer_sample", or :attr:`beam_width` is set.
            context (optional): An int Tensor of shape `[batch_size, length]`,
                containing the starting tokens for decoding.
                If context is set, the start_tokens will be ignored.
            context_sequence_length (optional): specify the length of context.
            softmax_temperature (optional): A float 0D Tensor, value to divide
                the logits by before computing the softmax. Larger values
                (above 1.0) result in more random samples. Must > 0. If `None`,
                1.0 is used.
                Used when :attr:`decoding_strategy` = "infer_sample"`.
            max_decoding_length (optional): An int scalar Tensor indicating
                the maximum allowed number of decoding steps.
                If `None` (default), use "max_decoding_length" defined in
                :attr:`hparams`. Ignored in "train_greedy" decoding.
            impute_finished (bool): If `True`, then states for batch
                entries which are marked as finished get copied through and
                the corresponding outputs get zeroed out.  This causes some
                slowdown at each time step, but ensures that the final state
                and outputs have the correct values and that backprop ignores
                time steps that were marked as finished. Ignored in
                "train_greedy" decoding.
            embedding (optional): Embedding used when
                "infer_greedy" or "infer_sample" `decoding_strategy`, or
                beam search, is used. This can be
                a callable or the `params` argument for
                :tf_main:`embedding_lookup <nn/embedding_lookup>`.
                If a callable, it can take a vector tensor of token `ids`,
                or take two arguments (`ids`, `times`), where `ids`
                is a vector tensor of token ids, and `times` is a vector tensor
                of time steps (i.e., position ids). The latter case can be used
                when attr:`embedding` is a combination of word embedding and
                position embedding.
            helper (optional): An instance of
                :tf_main:`Helper <contrib/seq2seq/Helper>` that defines the
                decoding strategy. If given, :attr:`decoding_strategy` is
                ignored.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. Controls dropout mode.
                If `None` (default), :func:`texar.global_mode`
                is used.

        Returns:

            - For **"train_greedy"** decoding, returns an instance of \
            :class:`~texar.modules.TransformerDecoderOutput` which contains\
            `sample_id` and `logits`.

            - For **"infer_greedy"** and **"infer_sample"** decoding or\
            decoding with :attr:`helper`, returns\
            a tuple `(outputs, sequence_lengths)`, where `outputs` is an \
            instance of :class:`~texar.modules.TransformerDecoderOutput` as\
            in "train_greedy", and `sequence_lengths` is a Tensor of shape\
            `[batch_size]` containing the length of each sample.

            - For **beam search** decoding, returns a `dict` containing keys\
            "sample_id" and "log_prob".

                - **"sample_id"** is an int Tensor of shape \
                `[batch_size, max_time, beam_width]` containing generated\
                token indexes. `sample_id[:,:,0]` is the highest-probable \
                sample.
                - **"log_prob"** is a float Tensor of shape \
                `[batch_size, beam_width]` containing the log probability \
                of each sequence sample.
        """

        if memory is not None:
            if memory_attention_bias is None:
                if memory_sequence_length is None:
                    raise ValueError("`memory_sequence_length` is required if "
                                     "`memory_attention_bias` is not given.")

                enc_padding = 1 - tf.sequence_mask(memory_sequence_length,
                                                   shape_list(memory)[1],
                                                   dtype=tf.float32)
                memory_attention_bias = attn.attention_bias_ignore_padding(
                    enc_padding)

        # record the context, which will be used in step function
        # for dynamic_decode
        if context is not None:
            start_tokens = context[:, 0]
            self.context = context[:, 1:]
            self.context_sequence_length = context_sequence_length - 1
        else:
            self.context = None

        self.embedding = embedding

        if helper is None and beam_width is None and \
                decoding_strategy == 'train_greedy':  # Teacher-forcing

            decoder_self_attention_bias = (attn.attention_bias_lower_triangle(
                shape_list(inputs)[1]))

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias=decoder_self_attention_bias,
                memory_attention_bias=memory_attention_bias,
                cache=None,
                mode=mode)
            logits = self._output_layer(decoder_output)
            preds = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
            rets = TransformerDecoderOutput(logits=logits, sample_id=preds)

        else:
            if max_decoding_length is None:
                max_decoding_length = self._hparams.max_decoding_length
            self.max_decoding_length = max_decoding_length
            if beam_width is None:  # Inference-like decoding
                # Prepare helper
                if helper is None:
                    if decoding_strategy == "infer_greedy":
                        helper = tx_helper.GreedyEmbeddingHelper(
                            embedding, start_tokens, end_token)
                    elif decoding_strategy == "infer_sample":
                        helper = tx_helper.SampleEmbeddingHelper(
                            embedding, start_tokens, end_token,
                            softmax_temperature)
                    else:
                        raise ValueError(
                            "Unknown decoding strategy: {}".format(
                                decoding_strategy))
                self._helper = helper

                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=False)
                if context is not None:
                    self.context = tf.pad(self.context, [[
                        0, 0
                    ], [0, max_decoding_length - shape_list(self.context)[1]]])

                outputs, _, sequence_lengths = dynamic_decode(
                    decoder=self,
                    impute_finished=impute_finished,
                    maximum_iterations=max_decoding_length,
                    output_time_major=False,
                    scope=self.variable_scope)

                if context is not None:
                    # Here the length of sample_id will be larger than that
                    # of logit by 1, because there will be a additional
                    # start_token in the returned sample_id.
                    # the start_id should be the first token of the
                    # given context
                    outputs = TransformerDecoderOutput(
                        logits=outputs.logits,
                        sample_id=tf.concat([
                            tf.expand_dims(start_tokens, 1), outputs.sample_id
                        ],
                                            axis=1))
                    sequence_lengths = sequence_lengths + 1
                rets = outputs, sequence_lengths

            else:  # Beam-search decoding
                # Ignore `decoding_strategy`; Assume `helper` is not set
                if helper is not None:
                    raise ValueError("Must not set 'beam_width' and 'helper' "
                                     "simultaneously.")
                _batch_size = shape_list(start_tokens)[0]
                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=True,
                                               batch_size=_batch_size)

                # The output format is different when running beam search
                sample_id, log_prob = self._beam_decode(
                    start_tokens,
                    end_token,
                    beam_width=beam_width,
                    length_penalty=length_penalty,
                    decode_length=max_decoding_length,
                )
                rets = {'sample_id': sample_id, 'log_prob': log_prob}

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return rets
    def _build(
            self,  # pylint: disable=arguments-differ, too-many-statements
            decoding_strategy='train_greedy',
            inputs=None,
            adjs=None,
            memory=None,
            memory_sequence_length=None,
            memory_attention_bias=None,
            beam_width=None,
            length_penalty=0.,
            start_tokens=None,
            end_token=None,
            context=None,
            context_sequence_length=None,
            softmax_temperature=None,
            max_decoding_length=None,
            impute_finished=False,
            embedding=None,
            helper=None,
            mode=None):
        """Performs decoding, recording adjacency masks of the input sequences.

        See `texar.modules.decoders.transformer_decoders.TransformerDecoder`
        for the details of the arguments not described here.

        Args:
            decoding_strategy: One of `"train_greedy"`, `"infer_greedy"` or
                `"infer_sample"`. Only consulted when both :attr:`helper`
                and :attr:`beam_width` are `None`.
            inputs: Decoder inputs for teacher-forcing decoding; its second
                dimension is used as the length of the lower-triangle
                self-attention bias.
            adjs: A 3D Tensor of shape `[batch_size, max_time, max_time]`,
                containing the adjacency matrices of input sequences.
                Required. Entries equal to 0 yield a mask value of 0, all
                other entries yield 1; the result is stored in
                `self.adj_masks`.
            memory: Optional memory to attend to.
            memory_sequence_length: Lengths of :attr:`memory`. Required if
                :attr:`memory` is given and :attr:`memory_attention_bias`
                is not.
            memory_attention_bias: Optional precomputed attention bias for
                :attr:`memory`; computed from
                :attr:`memory_sequence_length` when omitted.
            beam_width: If not `None`, beam-search decoding is used.
            length_penalty: Length penalty forwarded to beam search.
            start_tokens: Start tokens of decoding. Overridden by the first
                token of :attr:`context` when :attr:`context` is given.
            end_token: Token that marks the end of decoding.
            context: Optional token ids of a given prefix. The first column
                is used as `start_tokens`; the remainder is stored in
                `self.context`.
            context_sequence_length: Lengths of :attr:`context`. Must be
                given whenever :attr:`context` is given, since 1 is
                subtracted from it.
            softmax_temperature: Temperature used by `"infer_sample"`.
            max_decoding_length: Maximum number of decoding steps. Defaults
                to `self._hparams.max_decoding_length`.
            impute_finished: Forwarded to `dynamic_decode`.
            embedding: Embedding used by the default inference helpers;
                also stored in `self.embedding`.
            helper: Optional helper instance. Must not be combined with
                :attr:`beam_width`.
            mode: Forwarded to the self-attention stack in teacher-forcing
                decoding.

        Returns:
            - Teacher-forcing decoding (no helper, no beam width and
              `decoding_strategy == 'train_greedy'`): a
              `TransformerDecoderOutput` with `logits` and `sample_id`.
            - Other decoding without beam search: a tuple
              `(outputs, sequence_lengths)`.
            - Beam-search decoding: a dict with keys `'sample_id'` and
              `'log_prob'`.

        Raises:
            ValueError: If :attr:`adjs` is not given.
            ValueError: If :attr:`memory` is given but neither
                :attr:`memory_attention_bias` nor
                :attr:`memory_sequence_length` is.
            ValueError: If :attr:`decoding_strategy` is unknown.
            ValueError: If both :attr:`beam_width` and :attr:`helper` are
                set.
        """
        # `adjs` only defaults to `None` for signature compatibility; it is
        # consumed unconditionally below, so fail early with a clear message
        # instead of an opaque tensor-conversion error.
        if adjs is None:
            raise ValueError("`adjs` is required: the adjacency matrices of "
                             "the input sequences must be given.")

        # Get adjacency masks from adjs: 0 where `adjs == 0`, 1 elsewhere.
        self.adj_masks = 1 - tf.cast(tf.equal(adjs, 0), dtype=tf.float32)

        if memory is not None and memory_attention_bias is None:
            if memory_sequence_length is None:
                raise ValueError("`memory_sequence_length` is required if "
                                 "`memory_attention_bias` is not given.")

            # Derive the bias from the padding positions of the memory.
            enc_padding = 1 - tf.sequence_mask(memory_sequence_length,
                                               shape_list(memory)[1],
                                               dtype=tf.float32)
            memory_attention_bias = attn.attention_bias_ignore_padding(
                enc_padding)

        # Record the context, which will be used in the step function
        # for dynamic_decode.
        if context is not None:
            # The first context token plays the role of the start token.
            start_tokens = context[:, 0]
            self.context = context[:, 1:]
            self.context_sequence_length = context_sequence_length - 1
        else:
            self.context = None

        self.embedding = embedding

        if helper is None and beam_width is None and \
                decoding_strategy == 'train_greedy':  # Teacher-forcing

            decoder_self_attention_bias = (attn.attention_bias_lower_triangle(
                shape_list(inputs)[1]))

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias=decoder_self_attention_bias,
                memory_attention_bias=memory_attention_bias,
                cache=None,
                mode=mode)
            logits = self._output_layer(decoder_output)
            preds = tf.to_int32(tf.argmax(logits, axis=-1))
            rets = TransformerDecoderOutput(logits=logits, sample_id=preds)

        else:
            if max_decoding_length is None:
                max_decoding_length = self._hparams.max_decoding_length
            self.max_decoding_length = max_decoding_length
            if beam_width is None:  # Inference-like decoding
                # Prepare helper
                if helper is None:
                    if decoding_strategy == "infer_greedy":
                        helper = tx_helper.GreedyEmbeddingHelper(
                            embedding, start_tokens, end_token)
                    elif decoding_strategy == "infer_sample":
                        helper = tx_helper.SampleEmbeddingHelper(
                            embedding, start_tokens, end_token,
                            softmax_temperature)
                    else:
                        raise ValueError(
                            "Unknown decoding strategy: {}".format(
                                decoding_strategy))
                self._helper = helper

                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=False)
                if context is not None:
                    # Right-pad the context along the time dimension up to
                    # `max_decoding_length`.
                    self.context = tf.pad(self.context, [[
                        0, 0
                    ], [0, max_decoding_length - shape_list(self.context)[1]]])

                outputs, _, sequence_lengths = dynamic_decode(
                    decoder=self,
                    impute_finished=impute_finished,
                    maximum_iterations=max_decoding_length,
                    output_time_major=False,
                    scope=self.variable_scope)

                if context is not None:
                    # Here the length of sample_id will be larger than that
                    # of logits by 1, because there is an additional
                    # start token in the returned sample_id.
                    # The start token is the first token of the given
                    # context.
                    outputs = TransformerDecoderOutput(
                        logits=outputs.logits,
                        sample_id=tf.concat([
                            tf.expand_dims(start_tokens, 1), outputs.sample_id
                        ],
                                            axis=1))
                    sequence_lengths = sequence_lengths + 1
                rets = outputs, sequence_lengths

            else:  # Beam-search decoding
                # Ignore `decoding_strategy`; Assume `helper` is not set
                if helper is not None:
                    raise ValueError("Must not set 'beam_width' and 'helper' "
                                     "simultaneously.")
                _batch_size = shape_list(start_tokens)[0]
                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=True,
                                               batch_size=_batch_size)

                # The output format is different when running beam search
                sample_id, log_prob = self._beam_decode(
                    start_tokens,
                    end_token,
                    beam_width=beam_width,
                    length_penalty=length_penalty,
                    decode_length=max_decoding_length,
                )
                rets = {'sample_id': sample_id, 'log_prob': log_prob}

        # Collect the trainable variables only on the first call.
        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return rets
    def _build(self,
               decoding_strategy="train_greedy",
               initial_state=None,
               inputs=None,
               memory=None,
               sequence_length=None,
               embedding=None,
               start_tokens=None,
               end_token=None,
               softmax_temperature=None,
               max_decoding_length=None,
               impute_finished=False,
               output_time_major=False,
               input_time_major=False,
               helper=None,
               mode=None,
               **kwargs):
        """Sets up memory, helper and initial state, then runs decoding.

        The helper is resolved in this order of precedence:

        1. :attr:`helper`, if given, is used as is.
        2. Otherwise, if :attr:`decoding_strategy` is not `None`, a helper
           is created for `"train_greedy"`, `"infer_greedy"` or
           `"infer_sample"`.
        3. Otherwise the helper type and keyword arguments are read from
           `self._hparams.helper_train` or `self._hparams.helper_infer`,
           depending on :attr:`mode`; the call arguments and
           :attr:`kwargs` override the configured keyword arguments.

        If :attr:`initial_state` is `None`, a zero state is used. If
        :attr:`max_decoding_length` is `None`, the limit is taken from
        `self._hparams.max_decoding_length_train` or
        `self._hparams.max_decoding_length_infer` according to
        :attr:`mode`, each falling back to `utils.MAX_SEQ_LENGTH`.

        Returns:
            A tuple `(outputs, final_state, sequence_lengths)` as returned
            by `dynamic_decode`.

        Raises:
            ValueError: If :attr:`decoding_strategy` is not `None` and not
                one of the known strategies (and no :attr:`helper` is
                given).
        """
        # Hand the memory to every attention mechanism of the cell.
        for attention_mechanism in self._cell._attention_mechanisms:
            attention_mechanism.initialize_memory(memory)

        # Resolve the helper only when the caller did not supply one.
        if helper is None:
            if decoding_strategy is None:
                # Fall back to the helper configured in the hparams.
                if is_train_mode_py(mode):
                    helper_hparams = self._hparams.helper_train
                else:
                    helper_hparams = self._hparams.helper_infer
                helper_kwargs = copy.copy(helper_hparams.kwargs.todict())
                helper_type = helper_hparams.type
                helper_kwargs.update(
                    inputs=inputs,
                    sequence_length=sequence_length,
                    time_major=input_time_major,
                    embedding=embedding,
                    start_tokens=start_tokens,
                    end_token=end_token,
                    softmax_temperature=softmax_temperature)
                # Extra keyword arguments of the call take precedence.
                helper_kwargs.update(kwargs)
                helper = rnn_decoder_helpers.get_helper(
                    helper_type, **helper_kwargs)
            elif decoding_strategy == "train_greedy":
                helper = rnn_decoder_helpers._get_training_helper(
                    inputs, sequence_length, embedding, input_time_major)
            elif decoding_strategy == "infer_greedy":
                helper = tx_helper.GreedyEmbeddingHelper(
                    embedding, start_tokens, end_token)
            elif decoding_strategy == "infer_sample":
                helper = tx_helper.SampleEmbeddingHelper(
                    embedding, start_tokens, end_token, softmax_temperature)
            else:
                raise ValueError(
                    "Unknown decoding strategy: {}".format(decoding_strategy))
        self._helper = helper

        # Default to a zero initial state.
        if initial_state is None:
            initial_state = self.zero_state(batch_size=self.batch_size,
                                            dtype=tf.float32)
        self._initial_state = initial_state

        # Determine the maximum number of decoding steps.
        decode_limit = max_decoding_length
        if decode_limit is None:
            train_limit = self._hparams.max_decoding_length_train
            if train_limit is None:
                train_limit = utils.MAX_SEQ_LENGTH
            infer_limit = self._hparams.max_decoding_length_infer
            if infer_limit is None:
                infer_limit = utils.MAX_SEQ_LENGTH
            decode_limit = tf.cond(is_train_mode(mode),
                                   lambda: train_limit,
                                   lambda: infer_limit)
        self.max_decoding_length = decode_limit

        # Run the decoding loop.
        outputs, final_state, sequence_lengths = dynamic_decode(
            decoder=self,
            impute_finished=impute_finished,
            maximum_iterations=decode_limit,
            output_time_major=output_time_major)

        # Trainable variables are collected on the first call only.
        if self._built:
            return outputs, final_state, sequence_lengths

        self._add_internal_trainable_variables()
        # The cell may have been constructed externally, so its variables
        # are added explicitly.
        self._add_trainable_variable(
            layers.get_rnn_cell_trainable_variables(self._cell))
        if isinstance(self._output_layer, tf.layers.Layer):
            self._add_trainable_variable(
                self._output_layer.trainable_variables)
        # The beam-search cell may already be constructed and used.
        if self._beam_search_cell is not None:
            self._add_trainable_variable(
                self._beam_search_cell.trainable_variables)
        self._built = True

        return outputs, final_state, sequence_lengths