def _build(self,    # pylint: disable=arguments-differ
               memory,
               memory_sequence_length=None,
               memory_attention_bias=None,
               inputs=None,
               sequence_length=None,
               decoding_strategy='train_greedy',
               beam_width=1,
               alpha=0,
               start_tokens=None,
               end_token=None,
               max_decoding_length=None,
               mode=None):
        """Performs decoding.

        The decoder supports 4 decoding strategies. For the first 3 strategies,
        set :attr:`decoding_strategy` to the respective string.

        - **"train_greedy"**: decoding in teacher-forcing fashion \
          (i.e., feeding \
          ground truth to decode the next step), and for each step sample \
          is obtained by taking the `argmax` of logits. \
          Argument :attr:`inputs` is required for this strategy. \
          :attr:`sequence_length` is optional.
        - **"infer_greedy"**: decoding in inference fashion (i.e., feeding \
          `generated` sample to decode the next step), and for each
          step sample is obtained by taking the `argmax` of logits.\
          Arguments :attr:`(start_tokens, end_token)` are \
          required for this strategy, and argument \
          :attr:`max_decoding_length` is optional.
        - **"infer_sample"**: decoding in inference fashion, and for each step\
          sample is obtained by `random sampling` from the logits.
          Arguments :attr:`(start_tokens, end_token)` are \
          required for this strategy, and argument \
          :attr:`max_decoding_length` is optional.
        - **Beam Search**: set :attr:`beam_width` to > 1 to use beam search \
          decoding.\
          Arguments :attr:`(start_tokens, end_token)` are \
          required, and argument \
          :attr:`max_decoding_length` is optional.

        Args:
            memory: The memory to attend, e.g., the output of an RNN encoder.
                A Tensor of shape `[batch_size, memory_max_time, dim]`.
            memory_sequence_length (optional): A Tensor of shape `[batch_size]`
                containing the sequence lengths for the batch entries in
                memory. Used to create an attention bias if
                :attr:`memory_attention_bias` is not given. Ignored if
                `memory_attention_bias` is provided.
            memory_attention_bias (optional): A Tensor of shape
                `[batch_size, num_heads, memory_max_time, dim]`.
                An attention bias typically sets the value of a padding
                position to a large negative value for masking. If not given,
                :attr:`memory_sequence_length` is used to automatically
                create an attention bias.
            inputs (optional): Input tensor for teacher forcing decoding, of
                shape `[batch_size, target_max_time, emb_dim]` containing the
                target sequence word embeddings.
                Used when :attr:`decoding_strategy` is set to "train_greedy".
            sequence_length (optional): A Tensor of shape `[batch_size]`,
                containing the sequence length of :attr:`inputs`.
                Tokens beyond the respective sequence length are masked out.
                Used when :attr:`decoding_strategy` is set to
                "train_greedy".
            decoding_strategy (str): A string specifying the decoding
                strategy, including "train_greedy", "infer_greedy",
                "infer_sample".
                Different arguments are required based on the
                strategy. See above for details. Ignored if
                :attr:`beam_width` > 1.
            beam_width (int): Set to > 1 to use beam search.
            alpha (float): Length penalty coefficient.
                Refer to https://arxiv.org/abs/1609.08144
                for more details.
            start_tokens (optional): An int Tensor of shape `[batch_size]`,
                containing the start tokens.
                Used when `decoding_strategy` = "infer_greedy" or
                "infer_sample", or `beam_width` > 1.
            end_token (optional): An int 0D Tensor, the token that marks end
                of decoding.
                Used when `decoding_strategy` = "infer_greedy" or
                "infer_sample", or `beam_width` > 1.
            max_decoding_length (optional): An int scalar Tensor indicating
                the maximum allowed number of decoding steps.
                If `None` (default), use "max_decoding_length" defined in
                :attr:`hparams`. Ignored in "train_greedy" decoding.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. Controls dropout mode.
                If `None` (default), :func:`texar.global_mode`
                is used.

        Returns:

            - For **"train_greedy"** decoding, returns an instance of \
            :class:`~texar.modules.TransformerDecoderOutput` which contains\
            `sample_id` and `logits`.

            - For **"infer_greedy"** and **"infer_sample"** decoding, returns\
            a tuple `(outputs, sequence_lengths)`, where `outputs` is an \
            instance of :class:`~texar.modules.TransformerDecoderOutput` as\
            in "train_greedy", and `sequence_lengths` is a Tensor of shape\
            `[batch_size]` containing the length of each sample.

            - For **beam_search** decoding, returns a `dict` containing keys\
            "sample_id" and "log_prob".

                - **"sample_id"** is an int Tensor of shape \
                `[batch_size, max_time, beam_width]` containing generated\
                token indexes. `sample_id[:,:,0]` is the highest-probable \
                sample.
                - **"log_porb"** is a float Tensor of shape \
                `[batch_size, beam_width]` containing the log probability \
                of each sequence sample.
        """
        if memory_attention_bias is None:
            if memory_sequence_length is None:
                raise ValueError(
                    "`memory_sequence_length` is required if "
                    "`memory_attention_bias` is not given.")

            enc_padding = 1 - tf.sequence_mask(
                memory_sequence_length, tf.shape(memory)[1], dtype=tf.float32)
            memory_attention_bias = attn.attention_bias_ignore_padding(
                enc_padding)

        if beam_width <= 1 and decoding_strategy == 'train_greedy':
            if sequence_length is not None:
                inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)

            decoder_self_attention_bias = (
                attn.attention_bias_lower_triangle(
                    shape_list(inputs)[1]))
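            # Scale input embeddings by sqrt(dim) before adding position
            # embeddings, as in the original Transformer.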
            target_inputs = inputs * self._hparams.dim**0.5

            _, lengths, channels = shape_list(target_inputs)
            pos_embeds = self.position_embedder(lengths, channels)

            inputs = target_inputs + pos_embeds

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias=decoder_self_attention_bias,
                memory_attention_bias=memory_attention_bias,
                cache=None,
                mode=mode)
            logits = self.output_layer(decoder_output)
            preds = tf.to_int32(tf.argmax(logits, axis=-1))
            output = TransformerDecoderOutput(
                logits=logits,
                sample_id=preds
            )
            rets = output

        else:  # Inference decoding

            if max_decoding_length is None:
                max_decoding_length = self._hparams.max_decoding_length

            if beam_width <= 1:
                logits, preds, sequence_length = self._infer_decoding(
                    self._prepare_tokens_to_embeds,
                    start_tokens,
                    end_token,
                    decode_length=max_decoding_length,
                    memory=memory,
                    memory_attention_bias=memory_attention_bias,
                    decoding_strategy=decoding_strategy,
                )
                output = TransformerDecoderOutput(
                    logits=logits,
                    sample_id=preds)
                rets = output, sequence_length
            else:
                # The output format is different when running beam search
                sample_id, log_prob = self._beam_decode(
                    self._prepare_tokens_to_embeds,
                    start_tokens,
                    end_token,
                    beam_width=beam_width,
                    alpha=alpha,
                    decode_length=max_decoding_length,
                    memory=memory,
                    memory_attention_bias=memory_attention_bias,
                )
                predictions = {
                    'sample_id': sample_id,
                    'log_prob': log_prob
                }
                rets = predictions

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return rets
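
A minimal usage sketch of the strategies documented above. All names here
(`decoder`, `enc_output`, `src_lengths`, `tgt_embeds`, `tgt_lengths`,
`start_tokens`, `eos_id`) are hypothetical and assumed to come from an
encoder and embedder defined elsewhere:

    # Teacher-forcing training: returns a TransformerDecoderOutput.
    outputs = decoder(
        memory=enc_output, memory_sequence_length=src_lengths,
        inputs=tgt_embeds, sequence_length=tgt_lengths,
        decoding_strategy='train_greedy')

    # Greedy inference: returns (outputs, sequence_lengths).
    outputs, lengths = decoder(
        memory=enc_output, memory_sequence_length=src_lengths,
        decoding_strategy='infer_greedy',
        start_tokens=start_tokens, end_token=eos_id,
        max_decoding_length=50)

    # Beam search: returns a dict with 'sample_id' and 'log_prob'.
    preds = decoder(
        memory=enc_output, memory_sequence_length=src_lengths,
        beam_width=5, alpha=0.6,
        start_tokens=start_tokens, end_token=eos_id)
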
    def forward(self,  # type: ignore
                inputs: Optional[torch.Tensor] = None,
                sequence_length: Optional[torch.LongTensor] = None,
                memory: Optional[torch.Tensor] = None,
                memory_sequence_length: Optional[torch.LongTensor] = None,
                memory_attention_bias: Optional[torch.Tensor] = None,
                context: Optional[torch.Tensor] = None,
                context_sequence_length: Optional[torch.LongTensor] = None,
                helper: Optional[Helper] = None,
                decoding_strategy: str = 'train_greedy',
                max_decoding_length: Optional[int] = None,
                impute_finished: bool = False,
                infer_mode: Optional[bool] = None,
                beam_width: Optional[int] = None,
                length_penalty: float = 0.,
                **kwargs) \
            -> Union[
                TransformerDecoderOutput,
                Tuple[TransformerDecoderOutput, torch.LongTensor],
                Dict[str, torch.Tensor]]:
        r"""Performs decoding.

        The interface is very similar to that of RNN decoders
        (:class:`texar.modules.RNNDecoderBase`). In particular,
        the function provides **3 ways** to specify the decoding method, with
        varying flexibility:

        1. The :attr:`decoding_strategy` argument.

           - **"train_greedy"**: decoding in teacher-forcing fashion (i.e.,
             feeding ground truth to decode the next step), and for each step
             sample is obtained by taking the `argmax` of logits.
             Argument :attr:`inputs` is required for this strategy.
             :attr:`sequence_length` is optional.
           - **"infer_greedy"**: decoding in inference fashion (i.e., feeding
             `generated` sample to decode the next step), and for each step
             sample is obtained by taking the `argmax` of logits.
             Arguments :attr:`(start_tokens, end_token)` are
             required for this strategy, and argument
             :attr:`max_decoding_length` is optional.
           - **"infer_sample"**: decoding in inference fashion, and for each
             step sample is obtained by `random sampling` from the logits.
             Arguments :attr:`(start_tokens, end_token)` are required for this
             strategy, and argument :attr:`max_decoding_length` is optional.

           This argument is used only when arguments :attr:`helper` and
           :attr:`beam_width` are both `None`.

        2. The :attr:`helper` argument: An instance of subclass of
           :class:`texar.modules.decoders.Helper`.
           This provides a superset of the decoding strategies above.
           The interface is the same as in RNN decoders.
           Please refer to :meth:`texar.modules.RNNDecoderBase.forward` for
           detailed usage and examples.

           Note that although using a
           :class:`~texar.decoder.TrainingHelper` corresponds to the
           ``"train_greedy"`` strategy above, the implementation is *slower*
           than directly setting ``decoding_strategy="train_greedy"`` (though
           the output results are the same).

           Argument :attr:`max_decoding_length` is optional.

        3. **Beam search**: set :attr:`beam_width` to use beam search decoding.
           Arguments :attr:`(start_tokens, end_token)` are required,
           and argument :attr:`max_decoding_length` is optional.

           .. warning::
               Beam search is not yet implemented. Setting :attr:`beam_width`
               to any value greater than 1 would raise a
               :exc:`NotImplementedError`.

        Args:
            memory (optional): The memory to attend, e.g., the output of an RNN
                encoder. A :tensor:`Tensor` of shape
                ``[batch_size, memory_max_time, dim]``.
            memory_sequence_length (optional): A :tensor:`Tensor` of shape
                ``[batch_size]`` containing the sequence lengths for the batch
                entries in memory. Used to create an attention bias if
                :attr:`memory_attention_bias` is not given. Ignored if
                :attr:`memory_attention_bias` is provided.
            memory_attention_bias (optional): A :tensor:`Tensor` of shape
                ``[batch_size, num_heads, memory_max_time, dim]``.
                An attention bias typically sets the value of a padding
                position to a large negative value for masking. If not given,
                :attr:`memory_sequence_length` is used to automatically
                create an attention bias.
            inputs (optional): Input tensor for teacher forcing decoding, of
                shape ``[batch_size, target_max_time, emb_dim]`` containing the
                target sequence word embeddings. Used when
                :attr:`decoding_strategy` is set to ``"train_greedy"``.
            sequence_length (optional): A :tensor:`LongTensor` of shape
                ``[batch_size]``, containing the sequence length of
                :attr:`inputs`. Tokens beyond the respective sequence length are
                masked out.
                Used when :attr:`decoding_strategy` is set to
                ``"train_greedy"``.
            decoding_strategy (str): A string specifying the decoding
                strategy, including ``"train_greedy"``, ``"infer_greedy"``,
                ``"infer_sample"``.
                Different arguments are required based on the
                strategy. See above for details. Ignored if
                :attr:`beam_width` or :attr:`helper` is set.
            beam_width (int): Set to use beam search. If given,
                :attr:`decoding_strategy` is ignored.
            length_penalty (float): Length penalty coefficient used in beam
                search decoding. Refer to https://arxiv.org/abs/1609.08144
                for more details.
                It should be larger if longer sentences are desired.
            context (optional): A :tensor:`LongTensor` of shape
                ``[batch_size, length]``, containing the starting tokens for
                decoding. If context is set, ``start_tokens`` of the
                :class:`~texar.modules.Helper` will be ignored.
            context_sequence_length (optional): A :tensor:`LongTensor` of
                shape ``[batch_size]`` containing the sequence lengths of
                :attr:`context`.
            max_decoding_length (int, optional): The maximum allowed number of
                decoding steps.
                If `None` (default), use ``"max_decoding_length"`` defined in
                :attr:`hparams`. Ignored in ``"train_greedy"`` decoding.
            impute_finished (bool): If `True`, then states for batch
                entries which are marked as finished get copied through and
                the corresponding outputs get zeroed out.  This causes some
                slowdown at each time step, but ensures that the final state
                and outputs have the correct values and that backprop ignores
                time steps that were marked as finished. Ignored in
                ``"train_greedy"`` decoding.
            helper (optional): An instance of
                :class:`texar.modules.decoders.Helper`
                that defines the decoding strategy. If given,
                ``decoding_strategy`` and helper configurations in
                :attr:`hparams` are ignored.
            infer_mode (optional): If not `None`, overrides mode given by
                :attr:`self.training`.

        Returns:

            - For **"train_greedy"** decoding, returns an instance of
              :class:`~texar.modules.TransformerDecoderOutput` which contains
              `sample_id` and `logits`.

            - For **"infer_greedy"** and **"infer_sample"** decoding or
              decoding with :attr:`helper`, returns
              a tuple ``(outputs, sequence_lengths)``, where ``outputs`` is an
              instance of :class:`~texar.modules.TransformerDecoderOutput` as
              in `"train_greedy"`, and ``sequence_lengths`` is a
              :tensor:`LongTensor` of shape ``[batch_size]`` containing the
              length of each sample.

            - For **beam search** decoding, returns a ``dict`` containing keys
              ``"sample_id"`` and ``"log_prob"``.

                - ``"sample_id"`` is a :tensor:`LongTensor` of shape
                  ``[batch_size, max_time, beam_width]`` containing generated
                  token indexes. ``sample_id[:,:,0]`` is the highest-probable
                  sample.
                - ``"log_prob"`` is a :tensor:`Tensor` of shape
                  ``[batch_size, beam_width]`` containing the log probability
                  of each sequence sample.
        """

        if memory is not None:
            if memory_attention_bias is None:
                if memory_sequence_length is None:
                    raise ValueError("`memory_sequence_length` is required if "
                                     "`memory_attention_bias` is not given.")

                enc_padding = 1 - sequence_mask(memory_sequence_length,
                                                memory.size(1),
                                                dtype=torch.float32)
                memory_attention_bias = attn.attention_bias_ignore_padding(
                    enc_padding)

        # record the context, which will be used in step function
        # for dynamic_decode
        if context is not None:
            if context_sequence_length is None:
                raise ValueError("'context_sequence_length' must not be None"
                                 "when 'context' is specified.")
            self._state_context = context[:, 1:]
            self._state_context_sequence_length = context_sequence_length - 1
        else:
            self._state_context = None
            self._state_context_sequence_length = None

        # Faster code path for teacher-forcing training
        if (helper is None and beam_width is None
                and decoding_strategy == 'train_greedy'):
            if inputs is None:
                raise ValueError(
                    "'inputs' must not be None "
                    "when using 'train_greedy' decoding strategy.")
            if sequence_length is not None:
                inputs = mask_sequences(inputs, sequence_length)

            decoder_self_attention_bias = (attn.attention_bias_lower_triangle(
                inputs.size(1)))

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias,
                memory_attention_bias,
                cache=None)
            logits = self._output_layer(decoder_output)
            sample_id = torch.argmax(logits, dim=-1)

            return TransformerDecoderOutput(logits, sample_id)

        # Inference code path.
        if max_decoding_length is None:
            max_decoding_length = self._hparams.max_decoding_length

        self._state_max_decoding_length = max_decoding_length

        if beam_width is None or beam_width == 1:  # Inference-like decoding
            # Prepare helper
            if helper is None:
                kwargs.update(decoding_strategy=decoding_strategy)
                if context is not None:
                    kwargs.update(start_tokens=context[:, 0])
                helper = self._create_or_get_helper(infer_mode, **kwargs)
            assert isinstance(helper, EmbeddingHelper)

            self._state_cache = self._init_cache(memory,
                                                 memory_attention_bias,
                                                 beam_search_decoding=False,
                                                 batch_size=helper.batch_size)
            if context is not None:
                assert self._state_context is not None
                pad_length = max_decoding_length - self._state_context.size(1)
                if pad_length > 0:
                    self._state_context = torch.cat(
                        (self._state_context,
                         self._state_context.new_zeros(
                             self._state_context.size(0), pad_length)),
                        dim=1)

            outputs, cache, sequence_lengths = self.dynamic_decode(
                helper,
                inputs=None,
                sequence_length=None,
                initial_state=None,
                max_decoding_length=max_decoding_length,
                impute_finished=impute_finished)
            del cache  # not used

            if context is not None:
                # Here the length of sample_id will be larger than that of
                # logits by 1, because an additional start token (the first
                # token of the given context) is prepended to the returned
                # sample_id.
                start_tokens = context[:, 0]
                outputs = TransformerDecoderOutput(
                    logits=outputs.logits,
                    sample_id=torch.cat(
                        [start_tokens.unsqueeze(1), outputs.sample_id], dim=1))
                sequence_lengths = sequence_lengths + 1

            return outputs, sequence_lengths

        else:  # Beam-search decoding
            # Ignore `decoding_strategy`; assume `helper` is not set.
            if helper is not None:
                raise ValueError("Must not set 'beam_width' and 'helper' "
                                 "simultaneously.")
            if context is not None:
                start_tokens = context[:, 0]
            else:
                if 'start_tokens' not in kwargs:
                    raise ValueError(
                        "'start_tokens' must be specified when using "
                        "beam search decoding.")
                start_tokens = kwargs['start_tokens']
            _batch_size = start_tokens.size(0)
            self._state_cache = self._init_cache(memory,
                                                 memory_attention_bias,
                                                 beam_search_decoding=True,
                                                 batch_size=_batch_size)
            end_token: int = kwargs.get('end_token')  # type: ignore

            # The output format is different when running beam search.
            sample_id, log_prob = self._beam_decode(
                start_tokens,
                end_token,
                embedding_fn=kwargs['embedding'],
                beam_width=beam_width,
                length_penalty=length_penalty,
                decode_length=max_decoding_length)

            return {'sample_id': sample_id, 'log_prob': log_prob}
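
A usage sketch for this PyTorch version, mirroring the TF example above
(again with hypothetical names `decoder`, `enc_output`, `src_lengths`,
`tgt_embeds`, `tgt_lengths`, `start_tokens`, `eos_id`; inference arguments
are forwarded to the helper via **kwargs):

    # Teacher-forcing training:
    outputs = decoder(
        memory=enc_output, memory_sequence_length=src_lengths,
        inputs=tgt_embeds, sequence_length=tgt_lengths,
        decoding_strategy='train_greedy')

    # Greedy inference:
    outputs, lengths = decoder(
        memory=enc_output, memory_sequence_length=src_lengths,
        decoding_strategy='infer_greedy',
        start_tokens=start_tokens, end_token=eos_id,
        max_decoding_length=50)
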
    def _build(self, inputs, sequence_length, mode=None):
        """Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`,
                containing the word embeddings of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`.
            sequence_length: A 1D Tensor of shape `[batch_size]`. Input tokens
                beyond respective sequence lengths are masked out
                automatically.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, max_time, dim]` containing the
            encoded vectors.
        """
        # Multiply input embedding with the sqrt of its dimension for
        # normalization
        if not self._hparams.use_bert_config:
            inputs = inputs * self._hparams.dim**0.5
            inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)
        _, lengths, _ = shape_list(inputs)

        inputs_padding = 1 - tf.sequence_mask(
            sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
        if self._hparams.use_bert_config:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding, bias_value=-1e4)
        else:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding)

        encoder_self_attention_bias = ignore_padding

        positions = tf.expand_dims(tf.range(lengths, dtype=tf.int32), 0)
        pos_embeds = self.position_embedder(positions)

        input_embedding = inputs + pos_embeds

        if self._hparams.use_bert_config:
            x = layers.layer_normalize(input_embedding)
            x = tf.layers.dropout(x,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))
        else:
            x = tf.layers.dropout(input_embedding,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))

        # Just to keep consistent with BERT; it actually makes no difference.
        if self._hparams.use_bert_config:
            pad_remover = None
        else:
            pad_remover = utils.transformer_utils.PadRemover(inputs_padding)

        for i in range(self._hparams.num_blocks):
            with tf.variable_scope("layer_{}".format(i)):
                multihead_attention = self.multihead_attention_list[i]
                # trivial difference between BERT and original Transformer
                if self._hparams.use_bert_config:
                    _queries_input = x
                else:
                    _queries_input = layers.layer_normalize(x)

                attention_output = multihead_attention(
                    queries=_queries_input,
                    memory=_queries_input,
                    memory_attention_bias=encoder_self_attention_bias,
                    mode=mode,
                )
                attention_output = tf.layers.dropout(
                    attention_output,
                    rate=self._hparams.residual_dropout,
                    training=is_train_mode(mode),
                )
                x = x + attention_output
                with tf.variable_scope('output'):
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)
                        y = x
                    else:
                        y = layers.layer_normalize(x)
                poswise_network = self.poswise_networks[i]
                with tf.variable_scope(poswise_network.variable_scope):
                    original_shape = shape_list(y)
                    y = tf.reshape(y, [-1, self._hparams.dim])
                    if pad_remover:
                        y = tf.expand_dims(pad_remover.remove(y), axis=0)
                        # [1, batch_size*seq_length, hidden_dim]
                    layer_output = poswise_network(y, mode=mode)
                    sub_output = tf.layers.dropout(
                        layer_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode)
                    )
                    if pad_remover:
                        sub_output = tf.reshape(
                            pad_remover.restore(
                                tf.squeeze(sub_output, axis=0)),
                            original_shape)
                    else:
                        sub_output = tf.reshape(sub_output, original_shape)

                    x = x + sub_output
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)

        if not self._hparams.use_bert_config:
            x = layers.layer_normalize(x)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return x
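
The PadRemover used above packs the non-padding positions into one flat
batch so the position-wise feed-forward layers do no work on padding. A
rough sketch of the assumed semantics (not the texar implementation):

    import tensorflow as tf

    def remove_pads(x, pad_mask):
        # x: [batch*time, dim] flattened activations; pad_mask: [batch, time]
        # with 1.0 at padding positions. Keeps only the non-pad rows.
        keep = tf.where(tf.equal(tf.reshape(pad_mask, [-1]), 0))
        return tf.gather_nd(x, keep), keep

    def restore_pads(y, keep, total_len, dim):
        # Scatters the processed rows back into the original flattened
        # layout, zero-filling the padding rows.
        return tf.scatter_nd(keep, y, [total_len, dim])
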
    def _build(
            self,  # pylint: disable=arguments-differ
            memory=None,
            memory_sequence_length=None,
            memory_attention_bias=None,
            inputs=None,
            sequence_length=None,
            decoding_strategy='train_greedy',
            beam_width=None,
            length_penalty=0.,
            start_tokens=None,
            end_token=None,
            context=None,
            context_sequence_length=None,
            softmax_temperature=None,
            max_decoding_length=None,
            impute_finished=False,
            helper=None,
            mode=None):
        """Performs decoding.

        The interface is very similar to that of RNN decoders
        (:meth:`texar.modules.RNNDecoderBase._build`). In particular,
        the function provides **3 ways** to specify the decoding method, with
        varying flexibility:

        1. The :attr:`decoding_strategy` argument.

            - **"train_greedy"**: decoding in teacher-forcing fashion (i.e.,
              feeding ground truth to decode the next step), and for each step
              sample is obtained by taking the `argmax` of logits.
              Argument :attr:`inputs` is required for this strategy.
              :attr:`sequence_length` is optional.
            - **"infer_greedy"**: decoding in inference fashion (i.e., feeding
              `generated` sample to decode the next step), and for each step
              sample is obtained by taking the `argmax` of logits.
              Arguments :attr:`(start_tokens, end_token)` are
              required for this strategy, and argument
              :attr:`max_decoding_length` is optional.
            - **"infer_sample"**: decoding in inference fashion, and for each
              step sample is obtained by `random sampling` from the logits.
              Arguments :attr:`(start_tokens, end_token)` are required for this
              strategy, and argument :attr:`max_decoding_length` is optional.

           This argument is used only when arguments :attr:`helper` and
           :attr:`beam_width` are both `None`.

        2. The :attr:`helper` argument: An instance of subclass of
           :tf_main:`tf.contrib.seq2seq.Helper <contrib/seq2seq/Helper>`.
           This provides a superset of the decoding strategies above.
           The interface is the same as in RNN decoders.
           Please refer to :meth:`texar.modules.RNNDecoderBase._build` for
           detailed usage and examples.

           Note that although using a :tf_main:`TrainingHelper
           <contrib/seq2seq/TrainingHelper>` corresponds to the
           "train_greedy" strategy above, the implementation is *slower* than
           directly setting `decoding_strategy="train_greedy"` (though the
           output results are the same).

           Argument :attr:`max_decoding_length` is optional.

        3. **Beam search**: set :attr:`beam_width` to use beam search decoding.
           Arguments :attr:`(start_tokens, end_token)` are required,
           and argument :attr:`max_decoding_length` is optional.

        Args:
            memory (optional): The memory to attend, e.g., the output of an
                RNN encoder. A Tensor of shape
                `[batch_size, memory_max_time, dim]`.
            memory_sequence_length (optional): A Tensor of shape `[batch_size]`
                containing the sequence lengths for the batch entries in
                memory. Used to create an attention bias if
                :attr:`memory_attention_bias` is not given. Ignored if
                `memory_attention_bias` is provided.
            memory_attention_bias (optional): A Tensor of shape
                `[batch_size, num_heads, memory_max_time, dim]`.
                An attention bias typically sets the value of a padding
                position to a large negative value for masking. If not given,
                :attr:`memory_sequence_length` is used to automatically
                create an attention bias.
            inputs (optional): Input tensor for teacher forcing decoding, of
                shape `[batch_size, target_max_time, emb_dim]` containing the
                target sequence word embeddings.
                Used when :attr:`decoding_strategy` is set to "train_greedy".
            sequence_length (optional): A Tensor of shape `[batch_size]`,
                containing the sequence length of :attr:`inputs`.
                Tokens beyond the respective sequence length are masked out.
                Used when :attr:`decoding_strategy` is set to
                "train_greedy".
            decoding_strategy (str): A string specifying the decoding
                strategy, including "train_greedy", "infer_greedy",
                "infer_sample".
                Different arguments are required based on the
                strategy. See above for details. Ignored if
                :attr:`beam_width` or :attr:`helper` is set.
            beam_width (int): Set to use beam search. If given,
                :attr:`decoding_strategy` is ignored.
            length_penalty (float): Length penalty coefficient used in beam
                search decoding. Refer to https://arxiv.org/abs/1609.08144
                for more details.
                It should be larger if longer sentences are desired.
            start_tokens (optional): An int Tensor of shape `[batch_size]`,
                containing the start tokens.
                Used when :attr:`decoding_strategy` = "infer_greedy" or
                "infer_sample", or :attr:`beam_width` is set.
                Ignored when context is set.
            end_token (optional): An int 0D Tensor, the token that marks end
                of decoding.
                Used when :attr:`decoding_strategy` = "infer_greedy" or
                "infer_sample", or :attr:`beam_width` is set.
            context (optional): An int Tensor of shape `[batch_size, length]`,
                containing the starting tokens for decoding.
                If context is set, the start_tokens will be ignored.
            context_sequence_length (optional): An int Tensor of shape
                `[batch_size]`, containing the sequence lengths of `context`.
            softmax_temperature (optional): A float 0D Tensor, value to divide
                the logits by before computing the softmax. Larger values
                (above 1.0) result in more random samples. Must be > 0.
                If `None`, 1.0 is used.
                Used when :attr:`decoding_strategy` = "infer_sample".
            max_decoding_length (optional): An int scalar Tensor indicating
                the maximum allowed number of decoding steps.
                If `None` (default), use "max_decoding_length" defined in
                :attr:`hparams`. Ignored in "train_greedy" decoding.
            impute_finished (bool): If `True`, then states for batch
                entries which are marked as finished get copied through and
                the corresponding outputs get zeroed out.  This causes some
                slowdown at each time step, but ensures that the final state
                and outputs have the correct values and that backprop ignores
                time steps that were marked as finished. Ignored in
                "train_greedy" decoding.
            helper (optional): An instance of
                :tf_main:`Helper <contrib/seq2seq/Helper>` that defines the
                decoding strategy. If given, :attr:`decoding_strategy` is
                ignored.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. Controls dropout mode.
                If `None` (default), :func:`texar.global_mode`
                is used.

        Returns:

            - For **"train_greedy"** decoding, returns an instance of \
            :class:`~texar.modules.TransformerDecoderOutput` which contains\
            `sample_id` and `logits`.

            - For **"infer_greedy"** and **"infer_sample"** decoding or\
            decoding with :attr:`helper`, returns\
            a tuple `(outputs, sequence_lengths)`, where `outputs` is an \
            instance of :class:`~texar.modules.TransformerDecoderOutput` as\
            in "train_greedy", and `sequence_lengths` is a Tensor of shape\
            `[batch_size]` containing the length of each sample.

            - For **beam search** decoding, returns a `dict` containing keys\
            "sample_id" and "log_prob".

                - **"sample_id"** is an int Tensor of shape \
                `[batch_size, max_time, beam_width]` containing generated\
                token indexes. `sample_id[:,:,0]` is the highest-probable \
                sample.
                - **"log_prob"** is a float Tensor of shape \
                `[batch_size, beam_width]` containing the log probability \
                of each sequence sample.
        """

        if memory is not None:
            if memory_attention_bias is None:
                if memory_sequence_length is None:
                    raise ValueError("`memory_sequence_length` is required if "
                                     "`memory_attention_bias` is not given.")

                enc_padding = 1 - tf.sequence_mask(memory_sequence_length,
                                                   tf.shape(memory)[1],
                                                   dtype=tf.float32)
                memory_attention_bias = attn.attention_bias_ignore_padding(
                    enc_padding)

        # record the context, which will be used in step function
        # for dynamic_decode
        if context is not None:
            start_tokens = context[:, 0]
            self.context = context[:, 1:]
            self.context_sequence_length = context_sequence_length - 1
        else:
            self.context = None

        if helper is None and beam_width is None and \
                decoding_strategy == 'train_greedy': # Teacher-forcing
            if sequence_length is not None:
                inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)

            decoder_self_attention_bias = (attn.attention_bias_lower_triangle(
                shape_list(inputs)[1]))
            if self._hparams.scale_embeds:
                target_inputs = inputs * self._hparams.dim**0.5
            else:
                target_inputs = inputs

            _, lengths, _ = shape_list(target_inputs)
            positions = tf.expand_dims(tf.range(lengths, dtype=tf.int32), 0)
            pos_embeds = self.position_embedder(positions)

            inputs = target_inputs + pos_embeds

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias=decoder_self_attention_bias,
                memory_attention_bias=memory_attention_bias,
                cache=None,
                mode=mode)
            logits = self.output_layer(decoder_output)
            preds = tf.to_int32(tf.argmax(logits, axis=-1))
            rets = TransformerDecoderOutput(logits=logits, sample_id=preds)

        else:
            if max_decoding_length is None:
                max_decoding_length = self._hparams.max_decoding_length

            self._inputs_to_outputs = self._inputs_to_outputs_fn(
                max_decoding_length + 1)

            if beam_width is None:  # Inference-like decoding
                # Prepare helper
                if helper is not None:
                    # ignore `decoding_strategy`
                    pass
                else:
                    if decoding_strategy == "infer_greedy":
                        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                            self._embedding, start_tokens, end_token)
                    elif decoding_strategy == "infer_sample":
                        helper = tf.contrib.seq2seq.SampleEmbeddingHelper(
                            self._embedding, start_tokens, end_token,
                            softmax_temperature)
                    else:
                        raise ValueError(
                            "Unknown decoding strategy: {}".format(
                                decoding_strategy))
                self._helper = helper

                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=False)
                if context is not None:
                    self.context = tf.pad(
                        self.context,
                        [[0, 0],
                         [0, max_decoding_length - tf.shape(self.context)[1]]])

                outputs, cache, sequence_lengths = dynamic_decode(
                    decoder=self,
                    impute_finished=impute_finished,
                    maximum_iterations=max_decoding_length,
                    output_time_major=False,
                    scope=self.variable_scope)

                if context is not None:
                    # Here the length of sample_id will be larger than that of
                    # logits by 1, because an additional start token (the
                    # first token of the given context) is prepended to the
                    # returned sample_id.
                    outputs = TransformerDecoderOutput(
                        logits=outputs.logits,
                        sample_id=tf.concat(
                            [tf.expand_dims(start_tokens, 1),
                             outputs.sample_id], axis=1))
                    sequence_lengths = sequence_lengths + 1
                rets = outputs, sequence_lengths

            else:  # Beam-search decoding
                # Ignore `decoding_strategy`; assume `helper` is not set.
                if helper is not None:
                    raise ValueError("Must not set 'beam_width' and 'helper' "
                                     "simultaneously.")
                _batch_size = tf.shape(start_tokens)[0]
                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=True,
                                               batch_size=_batch_size)

                # The output format is different when running beam search
                sample_id, log_prob = self._beam_decode(
                    start_tokens,
                    end_token,
                    beam_width=beam_width,
                    length_penalty=length_penalty,
                    decode_length=max_decoding_length,
                )
                rets = {'sample_id': sample_id, 'log_prob': log_prob}

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return rets
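
For reference, the GNMT length penalty cited above (Wu et al., 2016,
https://arxiv.org/abs/1609.08144, Eq. 14) divides each hypothesis's log
probability by lp(Y) = ((5 + |Y|) / 6) ** alpha before beams are compared;
a one-function sketch:

    def gnmt_length_penalty(length, alpha):
        # Larger alpha yields a larger divisor, which shrinks the magnitude
        # of the (negative) log probability and thus favors longer
        # hypotheses.
        return ((5.0 + float(length)) / 6.0) ** alpha
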
    def _build(self, inputs, sequence_length, mode=None):
        """Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`,
                containing the word embeddings of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`.
            sequence_length: A 1D Tensor of shape `[batch_size]`. Input tokens
                beyond respective sequence lengths are masked out
                automatically.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, max_time, dim]` containing the
            encoded vectors.
        """
        # Multiply input embedding with the sqrt of its dimension for
        # normalization
        inputs = inputs * self._hparams.dim**0.5

        inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)

        _, lengths, _ = shape_list(inputs)

        inputs_padding = 1 - tf.sequence_mask(
            sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
        ignore_padding = attn.attention_bias_ignore_padding(inputs_padding)
        encoder_self_attention_bias = ignore_padding

        pos_embeds = self.position_embedder(lengths, self._hparams.dim)
        input_embedding = inputs + pos_embeds

        x = tf.layers.dropout(input_embedding,
                              rate=self._hparams.embedding_dropout,
                              training=is_train_mode(mode))
        pad_remover = utils.transformer_utils.PadRemover(inputs_padding)

        for i in range(self._hparams.num_blocks):
            with tf.variable_scope("layer_{}".format(i)):
                with tf.variable_scope('self_attention'):
                    selfatt_output = attn.multihead_attention(
                        queries=layers.layer_normalize(x),
                        memory=None,
                        memory_attention_bias=encoder_self_attention_bias,
                        num_heads=self._hparams.num_heads,
                        dropout_rate=self._hparams.attention_dropout,
                        num_units=self._hparams.dim,
                        scope='multihead_attention')
                    x = x + tf.layers.dropout(
                        selfatt_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode),
                    )

                poswise_network = FeedForwardNetwork(
                    hparams=self._hparams['poswise_feedforward'])
                with tf.variable_scope(poswise_network.variable_scope):
                    y = layers.layer_normalize(x)
                    original_shape = shape_list(y)
                    y = tf.reshape(y, [-1, self._hparams.dim])
                    y = tf.expand_dims(pad_remover.remove(y), axis=0)
                    # [1, batch_size*seq_length, hidden_dim]
                    sub_output = tf.layers.dropout(
                        poswise_network(y),
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode))
                    sub_output = tf.reshape(
                        pad_remover.restore(
                            tf.squeeze(sub_output, axis=0)),
                        original_shape)
                    x = x + sub_output

        encoder_output = layers.layer_normalize(x)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return encoder_output
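
attention_bias_ignore_padding, used throughout these modules, turns a 0/1
padding indicator into a large negative additive bias that broadcasts over
the attention logits. A minimal sketch of the assumed semantics (the default
bias value here is an assumption):

    import tensorflow as tf

    def attention_bias_ignore_padding(pad, bias_value=-1e9):
        # pad: [batch, time] float tensor with 1.0 at padding positions.
        # Result: [batch, 1, 1, time]; added to the attention logits before
        # the softmax so padded keys get near-zero attention weight.
        return tf.expand_dims(tf.expand_dims(pad * bias_value, 1), 1)
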
    def forward(self,  # type: ignore # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                sequence_length: torch.LongTensor) \
            -> torch.Tensor:
        r"""Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape ``[batch_size, max_time, dim]``,
                containing the embedding of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`. The input embedding is typically an
                aggregation of word embedding and position embedding.
            sequence_length: A 1D :tensor:`LongTensor` of shape
                ``[batch_size]``. Input tokens beyond respective sequence
                lengths are masked out automatically.

        Returns:
            A Tensor of shape ``[batch_size, max_time, dim]`` containing the
            encoded vectors.
        """
        # The input embedding is assumed to be already scaled/aggregated
        # (e.g., word plus position embeddings); no sqrt(dim) scaling is
        # applied here.

        inputs_padding = 1 - sequence_mask(sequence_length,
                                           inputs.size()[1]).float()
        if self._hparams.use_bert_config:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding, bias_value=-1e4)
        else:
            ignore_padding = attn.attention_bias_ignore_padding(inputs_padding)
        encoder_self_attention_bias = ignore_padding

        input_embedding = inputs
        if self._hparams.use_bert_config:
            x = self.input_normalizer(input_embedding)
            x = self.embed_dropout(x)
        else:
            x = self.embed_dropout(input_embedding)

        for i in range(self._hparams.num_blocks):
            # trivial difference between BERT and original Transformer
            if self._hparams.use_bert_config:
                _queries_input = x
            else:
                _queries_input = self.self_attn_layer_norm[i](x)

            attention_output = self.self_attns[i](
                queries=_queries_input,
                memory=_queries_input,
                memory_attention_bias=encoder_self_attention_bias,
            )

            attention_output = self.residual_dropout(attention_output)

            x = x + attention_output

            poswise_network = self.poswise_networks[i]
            poswise_normalizer = self.poswise_layer_norm[i]

            if self._hparams.use_bert_config:
                x = poswise_normalizer(x)
                y = x
            else:
                y = poswise_normalizer(x)

            original_shape = y.size()

            y = y.view(-1, self._hparams.dim)

            layer_output = poswise_network(y)
            sub_output = self.residual_dropout(layer_output)
            sub_output = sub_output.view(original_shape)

            x = x + sub_output
            if self._hparams.use_bert_config:
                x = self.output_layer_norm[i](x)

        if not self._hparams.use_bert_config:
            x = self.final_layer_normalizer(x)
        return x
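
The use_bert_config branches above implement the usual post-LN (BERT-style)
vs. pre-LN (tensor2tensor-style) residual blocks. Schematically, as a sketch
rather than the module code:

    import torch.nn as nn

    def pre_ln_block(x, sublayer, norm: nn.LayerNorm, dropout: nn.Dropout):
        # use_bert_config=False: normalize the sub-layer *input*.
        return x + dropout(sublayer(norm(x)))

    def post_ln_block(x, sublayer, norm: nn.LayerNorm, dropout: nn.Dropout):
        # use_bert_config=True: normalize *after* the residual connection.
        return norm(x + dropout(sublayer(x)))
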
    def _build(
            self,  # pylint: disable=arguments-differ, too-many-statements
            decoding_strategy='train_greedy',
            inputs=None,
            adjs=None,
            memory=None,
            memory_sequence_length=None,
            memory_attention_bias=None,
            beam_width=None,
            length_penalty=0.,
            start_tokens=None,
            end_token=None,
            context=None,
            context_sequence_length=None,
            softmax_temperature=None,
            max_decoding_length=None,
            impute_finished=False,
            embedding=None,
            helper=None,
            mode=None):
        """Performs decoding.

        See :class:`texar.modules.decoders.transformer_decoders.TransformerDecoder`
        for details.

        Args:
            adjs: A 3D Tensor of shape `[batch_size, max_time, max_time]`,
                containing the adjacency matrices of input sequences.
        """
        # Get adjacency masks from adjs
        self.adj_masks = 1 - tf.cast(tf.equal(adjs, 0), dtype=tf.float32)

        if memory is not None:
            if memory_attention_bias is None:
                if memory_sequence_length is None:
                    raise ValueError("`memory_sequence_length` is required if "
                                     "`memory_attention_bias` is not given.")

                enc_padding = 1 - tf.sequence_mask(memory_sequence_length,
                                                   shape_list(memory)[1],
                                                   dtype=tf.float32)
                memory_attention_bias = attn.attention_bias_ignore_padding(
                    enc_padding)

        # record the context, which will be used in step function
        # for dynamic_decode
        if context is not None:
            start_tokens = context[:, 0]
            self.context = context[:, 1:]
            self.context_sequence_length = context_sequence_length - 1
        else:
            self.context = None

        self.embedding = embedding

        if helper is None and beam_width is None and \
                decoding_strategy == 'train_greedy':  # Teacher-forcing

            decoder_self_attention_bias = (attn.attention_bias_lower_triangle(
                shape_list(inputs)[1]))

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias=decoder_self_attention_bias,
                memory_attention_bias=memory_attention_bias,
                cache=None,
                mode=mode)
            logits = self._output_layer(decoder_output)
            preds = tf.to_int32(tf.argmax(logits, axis=-1))
            rets = TransformerDecoderOutput(logits=logits, sample_id=preds)

        else:
            if max_decoding_length is None:
                max_decoding_length = self._hparams.max_decoding_length
            self.max_decoding_length = max_decoding_length
            if beam_width is None:  # Inference-like decoding
                # Prepare helper
                if helper is None:
                    if decoding_strategy == "infer_greedy":
                        helper = tx_helper.GreedyEmbeddingHelper(
                            embedding, start_tokens, end_token)
                    elif decoding_strategy == "infer_sample":
                        helper = tx_helper.SampleEmbeddingHelper(
                            embedding, start_tokens, end_token,
                            softmax_temperature)
                    else:
                        raise ValueError(
                            "Unknown decoding strategy: {}".format(
                                decoding_strategy))
                self._helper = helper

                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=False)
                if context is not None:
                    self.context = tf.pad(
                        self.context,
                        [[0, 0],
                         [0, max_decoding_length - shape_list(self.context)[1]]])

                outputs, _, sequence_lengths = dynamic_decode(
                    decoder=self,
                    impute_finished=impute_finished,
                    maximum_iterations=max_decoding_length,
                    output_time_major=False,
                    scope=self.variable_scope)

                if context is not None:
                    # Here the length of sample_id will be larger than that of
                    # logits by 1, because an additional start token (the
                    # first token of the given context) is prepended to the
                    # returned sample_id.
                    outputs = TransformerDecoderOutput(
                        logits=outputs.logits,
                        sample_id=tf.concat(
                            [tf.expand_dims(start_tokens, 1),
                             outputs.sample_id], axis=1))
                    sequence_lengths = sequence_lengths + 1
                rets = outputs, sequence_lengths

            else:  # Beam-search decoding
                # Ignore `decoding_strategy`; Assume `helper` is not set
                if helper is not None:
                    raise ValueError("Must not set 'beam_width' and 'helper' "
                                     "simultaneously.")
                _batch_size = shape_list(start_tokens)[0]
                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=True,
                                               batch_size=_batch_size)

                # The output format is different when running beam search
                sample_id, log_prob = self._beam_decode(
                    start_tokens,
                    end_token,
                    beam_width=beam_width,
                    length_penalty=length_penalty,
                    decode_length=max_decoding_length,
                )
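                # `sample_id` is expected to have shape
                # [batch_size, max_time, beam_width], and `log_prob` shape
                # [batch_size, beam_width].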
                rets = {'sample_id': sample_id, 'log_prob': log_prob}

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return rets

    def _build(self,
               inputs,
               memory,
               sequence_length,
               memory_sequence_length,
               adjs,
               encoder_output,
               mode=None):
        """Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`,
                containing the embedding of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`. The input embedding is typically an aggregation
                of word embedding and position embedding.
            memory: A 3D Tensor of shape `[batch_size, memory_max_time, dim]`,
                containing the embedding of memory sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`. The memory embedding is typically an
                aggregation of word embedding and position embedding.
            sequence_length: A 1D Tensor of shape `[batch_size]`. Input tokens
                beyond respective sequence lengths are masked out
                automatically.
            memory_sequence_length: A 1D Tensor of shape `[batch_size]`.
                Memory tokens beyond respective sequence lengths are masked
                out automatically.
            adjs: A 3D Tensor of shape `[batch_size, max_time, max_time]`,
                containing the adjacency matrices of input sequences.
            encoder_output: bool. If `True`, returns the encoder-like
                embeddings; if `False`, returns a
                :class:`CrossGraphTransformerFixedLengthDecoderOutput`.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, max_time, dim]` containing the
            encoded vectors if :attr:`encoder_output` is `True`; otherwise a
            :class:`CrossGraphTransformerFixedLengthDecoderOutput` containing
            the logits, sample ids, and probs.
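
        Example:
            A minimal usage sketch. The shapes, input tensors, and the
            handle `module` are illustrative assumptions, not part of
            this API:

            .. code-block:: python

                # inputs, memory: [batch_size=2, max_time=5, dim]
                outputs = module(
                    inputs=inputs,
                    memory=memory,
                    sequence_length=tf.constant([5, 3]),
                    memory_sequence_length=tf.constant([5, 4]),
                    adjs=adjs,            # [2, 5, 5] adjacency matrices
                    encoder_output=True)  # returns a [2, 5, dim] Tensor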
        """
        # Get adjacency masks from adjs
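        # e.g., a single-example adjacency slice [[0., 3.], [1., 0.]] yields
        # the mask [[0., 1.], [1., 0.]]: nonzero entries are kept, zero
        # entries are masked out.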
        adj_masks = 1 - tf.cast(tf.equal(adjs, 0), dtype=tf.float32)

        # Build a padding mask (1.0 at padding positions, 0.0 elsewhere)
        # and turn it into an attention bias that masks out the padding
        inputs_padding = 1 - tf.sequence_mask(
            sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
        if self._hparams.use_bert_config:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding, bias_value=-1e4)
        else:
            ignore_padding = attn.attention_bias_ignore_padding(inputs_padding)
        encoder_self_attention_bias = ignore_padding
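        # The bias places a large negative value (-1e4 in the BERT setting
        # above, the library default otherwise) at padding positions, so
        # their post-softmax attention weights are effectively zero.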

        input_embedding = inputs  # shape (batch_size, max_time, dim)

        if self._hparams.use_bert_config:
            x = layers.layer_normalize(input_embedding)
            x = tf.layers.dropout(x,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))
        else:
            x = tf.layers.dropout(input_embedding,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))

        # Pad removal is disabled under the BERT config only for
        # consistency with BERT; this makes no difference to the results.
        if self._hparams.use_bert_config:
            pad_remover = None
        else:
            pad_remover = utils.transformer_utils.PadRemover(inputs_padding)
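        # PadRemover strips padded positions before the position-wise
        # feed-forward layers and restores them afterwards, avoiding wasted
        # computation on tokens that are masked out anyway.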

        for i in range(self._hparams.num_blocks):
            with tf.variable_scope("layer_{}".format(i)):
                graph_multihead_attention = (
                    self.graph_multihead_attention_list[i])

                # A minor difference between BERT and the original
                # Transformer: BERT feeds `x` to the attention directly,
                # while the original layer-normalizes the queries first.
                if self._hparams.use_bert_config:
                    _queries_input = x
                else:
                    _queries_input = layers.layer_normalize(x)

                attention_output = graph_multihead_attention(
                    queries=_queries_input,
                    memory=memory,
                    adj_masks=adj_masks,
                    memory_attention_bias=encoder_self_attention_bias,
                    mode=mode,
                )
                attention_output = tf.layers.dropout(
                    attention_output,
                    rate=self._hparams.residual_dropout,
                    training=is_train_mode(mode),
                )
                # attention_output: weighted sum of the memory values V, with
                # weights obtained by matching the queries against the memory
                # keys; added back to x as a residual connection.
                x = x + attention_output
                with tf.variable_scope('output'):
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)
                        y = x
                    else:
                        y = layers.layer_normalize(x)

                poswise_network = self.poswise_networks[i]
                with tf.variable_scope(poswise_network.variable_scope):
                    original_shape = shape_list(y)
                    y = tf.reshape(y, [-1, self._hparams.dim])
                    if pad_remover:
                        y = tf.expand_dims(pad_remover.remove(y), axis=0)
                        # [1, batch_size*seq_length, hidden_dim]
                    layer_output = poswise_network(y, mode=mode)
                    sub_output = tf.layers.dropout(
                        layer_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode))
                    if pad_remover:
                        sub_output = tf.reshape(
                            pad_remover.restore(
                                tf.squeeze(sub_output, axis=0)),
                            original_shape)
                    else:
                        sub_output = tf.reshape(sub_output, original_shape)

                    x = x + sub_output
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)

        if not self._hparams.use_bert_config:
            x = layers.layer_normalize(x)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        if encoder_output:
            return x

        logits = self._output_layer(x)
        sample_ids = tf.to_int32(tf.argmax(logits, axis=-1))
        # `probs` is left empty as a placeholder; two differentiable
        # sampling alternatives are kept below for reference:
        # probs = GumbelSoftmax(self._tau, logits=logits).sample()
        # probs = tf.nn.softmax(logits / self._tau)  # vanilla softmax
        probs = ''

        rets = CrossGraphTransformerFixedLengthDecoderOutput(
            logits=logits, sample_id=sample_ids, probs=probs)

        return rets