Example #1
    def _build(self, inputs, mode=None):
        """

        Args:
            inputs:

        Returns:
        """
        training = is_train_mode(mode)

        prev_outputs = inputs
        for layer_id, layer in enumerate(self._layers):
            if isinstance(layer, (tf.layers.Dropout,
                                  tf.layers.BatchNormalization)):
                outputs = layer(prev_outputs, training=training)
            else:
                outputs = layer(prev_outputs)
            self._layer_outputs.append(outputs)
            self._layer_outputs_by_name[self._layer_names[layer_id]] = outputs
            prev_outputs = outputs

        if not self._built:
            self._add_internal_trainable_variables()
            # Add trainable variables of `self._layers` which may be
            # constructed externally.
            for layer in self._layers:
                self._add_trainable_variable(layer.trainable_variables)
            self._built = True

        return outputs
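A minimal usage sketch of the module above (a texar-style feed-forward network whose `__call__` invokes `_build`; the constructor and hparams names here are illustrative assumptions, not taken from the example):

network = FeedForwardNetwork(hparams=network_hparams)
# With `mode` omitted, dropout/batch-norm follow texar.global_mode().
outputs = network(inputs)
# Or force a fixed mode, e.g. to disable dropout at evaluation:
outputs_eval = network(inputs, mode=tf.estimator.ModeKeys.EVAL)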
Example #2
def _forward_output_layers(inputs,
                           input_size,
                           output_layer,
                           time_major,
                           hparams,
                           mode,
                           sequence_length=None):
    """Forwards inputs through the output layers.

    Args:
        inputs: A Tensor of shape `[batch_size, max_time] + input_size` if
            :attr:`time_major=False`, or shape
            `[max_time, batch_size] + input_size` if :attr:`time_major=True`.

    Returns:
        A pair :attr:`(outputs, outputs_size)`, where

        - :attr:`outputs`: A Tensor of shape \
          `[batch_size, max_time] + outputs_size`.

        - :attr:`outputs_size`: An `int` or 1D `int` array representing the \
          output size.
    """
    if output_layer is None:
        return inputs, input_size

    if hparams is None:
        # output_layer was passed in from the constructor
        if isinstance(output_layer, (list, tuple)):
            raise ValueError('output_layer must not be a list or tuple.')
        output, output_size = _forward_single_output_layer(
            inputs, input_size, output_layer)
    else:
        # output_layer was built based on hparams
        output_layer = _to_list(output_layer)

        dropout_layer_ids = _to_list(hparams.dropout_layer_ids)
        if len(dropout_layer_ids) > 0:
            training = is_train_mode(mode)

        output = inputs
        output_size = input_size
        for i, layer in enumerate(output_layer):
            if i in dropout_layer_ids:
                output = _apply_dropout(output, time_major, hparams, training)
            output, output_size = _forward_single_output_layer(
                output, output_size, layer)

        if len(output_layer) in dropout_layer_ids:
            output = _apply_dropout(output, time_major, hparams, training)

    if sequence_length is not None:
        output = mask_sequences(output,
                                sequence_length,
                                time_major=time_major,
                                tensor_rank=3)

    return output, output_size
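A hedged usage sketch of the helper above, showing its two call modes (the layer and hparams names are illustrative assumptions):

# Constructor-supplied layer: `hparams=None`, a single layer is applied.
output, output_size = _forward_output_layers(
    inputs, input_size=hidden_dim, output_layer=tf.layers.Dense(vocab_size),
    time_major=False, hparams=None, mode=None)

# hparams-built layer list: dropout is applied at the positions listed in
# `hparams.dropout_layer_ids`, toggled by `mode`.
output, output_size = _forward_output_layers(
    inputs, input_size=hidden_dim, output_layer=layer_list,
    time_major=False, hparams=output_layer_hparams,
    mode=tf.estimator.ModeKeys.TRAIN)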
Example #3
    def test_mode(self):
        """ Tests mode related utilities.
        """
        training = mode.is_train_mode(None)
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            training_ = sess.run(training)
            self.assertTrue(training_)

            training_ = sess.run(
                training,
                feed_dict={context.global_mode(): tf.estimator.ModeKeys.TRAIN})
            self.assertTrue(training_)

            training_ = sess.run(
                training,
                feed_dict={context.global_mode(): tf.estimator.ModeKeys.EVAL})
            self.assertFalse(training_)

        training = mode.is_train_mode(tf.estimator.ModeKeys.TRAIN)
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            training_ = sess.run(training)
            self.assertTrue(training_)
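The test above pins down the behavior of `is_train_mode`: with `mode=None` it defers to the feedable global-mode placeholder (TRAIN by default), and with an explicit mode it compares against `ModeKeys.TRAIN`. A minimal sketch consistent with that behavior (an illustrative reconstruction, not the actual texar source):

def is_train_mode(mode):
    """Returns a bool scalar Tensor: whether `mode` is TRAIN.

    Falls back to the global mode placeholder when `mode` is None, so the
    result can still be switched via feed_dict at session-run time.
    """
    if mode is None:
        mode = context.global_mode()
    return tf.equal(mode, tf.estimator.ModeKeys.TRAIN)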
Example #4
    def call(self, inputs, mode=None):  # pylint: disable=arguments-differ
        training = is_train_mode(mode)

        outputs = inputs
        for layer in self._layers:
            if isinstance(layer, (tf.layers.Dropout,
                                  tf.layers.BatchNormalization)):
                outputs = layer(outputs, training=training)
            else:
                outputs = layer(outputs)

        if not self.built:
            self._collect_weights()

        return outputs
Example #5
    def _build(self, inputs, mode=None):
        """Feeds forward inputs through the network layers and returns outputs.

        Args:
            inputs: The inputs to the network. The requirements on inputs
                depend on the first layer and subsequent layers in the
                network.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. If `None`,
                :func:`texar.global_mode` is used.

        Returns:
            The output of the network.
        """
        training = is_train_mode(mode)

        prev_outputs = inputs
        for layer_id, layer in enumerate(self._layers):
            if isinstance(layer, (tf.layers.Dropout,
                                  tf.layers.BatchNormalization)):
                outputs = layer(prev_outputs, training=training)
            else:
                outputs = layer(prev_outputs)
            self._layer_outputs.append(outputs)
            self._layer_outputs_by_name[self._layer_names[layer_id]] = outputs
            prev_outputs = outputs

        if not self._built:
            self._add_internal_trainable_variables()
            # Add trainable variables of `self._layers` which may be
            # constructed externally.
            for layer in self._layers:
                self._add_trainable_variable(layer.trainable_variables)
            self._built = True

        return outputs
Example #6
    def _self_attention_stack(self,
                              inputs,
                              memory,
                              decoder_self_attention_bias=None,
                              memory_attention_bias=None,
                              cache=None,
                              mode=None):
        """Stacked multihead attention module.
        """
        inputs = tf.layers.dropout(inputs,
                                   rate=self._hparams.embedding_dropout,
                                   training=is_train_mode(mode))
        if cache is not None:
            memory_attention_bias = \
                cache['memory_attention_bias']
        else:
            assert decoder_self_attention_bias is not None

        x = inputs
        for i in range(self._hparams.num_blocks):
            layer_name = 'layer_{}'.format(i)
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                with tf.variable_scope("self_attention"):
                    multihead_attention = \
                        self.multihead_attentions['self_att'][i]
                    selfatt_output = multihead_attention(
                        queries=layers.layer_normalize(x),
                        memory=None,
                        memory_attention_bias=decoder_self_attention_bias,
                        cache=layer_cache,
                        mode=mode,
                    )
                    x = x + tf.layers.dropout(
                        selfatt_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode),
                    )
                if memory is not None:
                    with tf.variable_scope('encdec_attention'):
                        multihead_attention = \
                            self.multihead_attentions['encdec_att'][i]
                        encdec_output = multihead_attention(
                            queries=layers.layer_normalize(x),
                            memory=memory,
                            memory_attention_bias=memory_attention_bias,
                            mode=mode,
                        )
                        x = x + tf.layers.dropout(
                            encdec_output,
                            rate=self._hparams.residual_dropout,
                            training=is_train_mode(mode))
                poswise_network = self.poswise_networks[i]
                with tf.variable_scope('past_poswise_ln'):
                    sub_output = tf.layers.dropout(
                        poswise_network(layers.layer_normalize(x)),
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode),
                    )
                    x = x + sub_output

        return layers.layer_normalize(x)
Example #7
    def _self_attention_stack(self,
                              inputs,
                              memory,
                              decoder_self_attention_bias=None,
                              memory_attention_bias=None,
                              cache=None,
                              mode=None):
        """Stacked multihead attention module.
        """
        inputs = tf.layers.dropout(inputs,
                                   rate=self._hparams.embedding_dropout,
                                   training=is_train_mode(mode))
        if cache is not None:
            memory_attention_bias = \
                cache['memory_attention_bias']
        else:
            assert decoder_self_attention_bias is not None

        x = inputs
        for i in range(self._hparams.num_blocks):
            layer_name = 'layer_{}'.format(i)
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                with tf.variable_scope("self_attention"):
                    selfatt_output = attn.multihead_attention(
                        queries=layers.layer_normalize(x),
                        memory=None,
                        memory_attention_bias=decoder_self_attention_bias,
                        num_units=self._hparams.dim,
                        num_heads=self._hparams.num_heads,
                        dropout_rate=self._hparams.attention_dropout,
                        cache=layer_cache,
                        scope="multihead_attention",
                    )
                    x = x + tf.layers.dropout(
                        selfatt_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode),
                    )
                if memory is not None:
                    with tf.variable_scope('encdec_attention'):
                        encdec_output = attn.multihead_attention(
                            queries=layers.layer_normalize(x),
                            memory=memory,
                            memory_attention_bias=memory_attention_bias,
                            num_units=self._hparams.dim,
                            num_heads=self._hparams.num_heads,
                            dropout_rate=self._hparams.attention_dropout,
                            scope="multihead_attention"
                        )
                        x = x + tf.layers.dropout(
                            encdec_output,
                            rate=self._hparams.residual_dropout,
                            training=is_train_mode(mode))
                poswise_network = FeedForwardNetwork(
                    hparams=self._hparams['poswise_feedforward'])
                with tf.variable_scope(poswise_network.variable_scope):
                    sub_output = tf.layers.dropout(
                        poswise_network(layers.layer_normalize(x)),
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode),
                    )
                    x = x + sub_output

        return layers.layer_normalize(x)
Example #8
    def _build(self,
               positions=None,
               sequence_length=None,
               mode=None,
               **kwargs):
        """Embeds the positions.

        Either :attr:`positions` or :attr:`sequence_length` is required:

            - If both are given, :attr:`sequence_length` is used to mask out \
            embeddings of those time steps beyond the respective sequence \
            lengths.
            - If only :attr:`sequence_length` is given, then positions \
            from `0` to `sequence_length-1` are embedded.

        Args:
            positions (optional): An integer tensor containing the position
                ids to embed.
            sequence_length (optional): An integer tensor of shape
                `[batch_size]`. Time steps beyond
                the respective sequence lengths will have zero-valued
                embeddings.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout will be
                controlled by :func:`texar.global_mode`.
            kwargs: Additional keyword arguments for
                :tf_main:`tf.nn.embedding_lookup <nn/embedding_lookup>` besides
                :attr:`params` and :attr:`ids`.

        Returns:
            A `Tensor` of shape `shape(inputs) + embedding dimension`.
        """
        # Gets embedder inputs
        inputs = positions
        if positions is None:
            if sequence_length is None:
                raise ValueError(
                    'Either `positions` or `sequence_length` is required.')
            max_length = tf.reduce_max(sequence_length)
            single_inputs = tf.range(start=0, limit=max_length, dtype=tf.int32)
            # Expands `single_inputs` to have shape [batch_size, max_length]
            expander = tf.expand_dims(tf.ones_like(sequence_length), -1)
            inputs = expander * tf.expand_dims(single_inputs, 0)
        ids_rank = len(inputs.shape.dims)

        embedding = self._embedding

        is_training = is_train_mode(mode)

        # Gets dropout strategy
        st = self._hparams.dropout_strategy
        if positions is None and st == 'item':
            # If `inputs` is based on `sequence_length`, the dropout
            # strategies 'item' and 'item_type' have the same effect. Use
            # 'item_type' to avoid the unknown noise_shape required by the
            # 'item' strategy.
            st = 'item_type'

        # Dropouts as 'item_type' before embedding
        if st == 'item_type':
            dropout_layer = self._get_dropout_layer(self._hparams,
                                                    dropout_strategy=st)
            if dropout_layer:
                embedding = dropout_layer.apply(inputs=embedding,
                                                training=is_training)

        # Embeds
        outputs = tf.nn.embedding_lookup(embedding, inputs, **kwargs)

        # Dropouts as 'item' or 'elements' after embedding
        if st != 'item_type':
            dropout_layer = self._get_dropout_layer(self._hparams,
                                                    ids_rank=ids_rank,
                                                    dropout_input=outputs,
                                                    dropout_strategy=st)
            if dropout_layer:
                outputs = dropout_layer.apply(inputs=outputs,
                                              training=is_training)

        # Optionally masks
        if sequence_length is not None:
            outputs = mask_sequences(outputs,
                                     sequence_length,
                                     tensor_rank=len(inputs.shape.dims) +
                                     self._dim_rank)

        return outputs
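A hypothetical usage sketch of the position embedder above (assuming the enclosing module's `__call__` routes to `_build`; the constructor argument is illustrative):

pos_embedder = PositionEmbedder(position_size=max_position)
# Embeds positions 0 .. length-1 of each sequence; steps beyond the
# respective lengths get zero-valued embeddings.
pos_embeds = pos_embedder(sequence_length=lengths)
# Alternatively, embed explicit position ids and mask with lengths:
pos_embeds = pos_embedder(positions=position_ids, sequence_length=lengths)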
Example #9
    def _build(self, inputs, sequence_length, mode=None):
        """Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`,
                containing the word embeddings of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`.
            sequence_length: A 1D Tensor of shape `[batch_size]`. Input tokens
                beyond respective sequence lengths are masked out
                automatically.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, max_time, dim]` containing the
            encoded vectors.
        """
        # Multiply input embedding with the sqrt of its dimension for
        # normalization
        if not self._hparams.use_bert_config:
            inputs = inputs * self._hparams.dim**0.5
            inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)
        _, lengths, _ = shape_list(inputs)

        inputs_padding = 1 - tf.sequence_mask(
            sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
        if self._hparams.use_bert_config:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding, bias_value=-1e4)
        else:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding)

        encoder_self_attention_bias = ignore_padding

        positions = tf.expand_dims(tf.range(lengths, dtype=tf.int32), 0)
        pos_embeds = self.position_embedder(positions)

        input_embedding = inputs + pos_embeds

        if self._hparams.use_bert_config:
            x = layers.layer_normalize(input_embedding)
            x = tf.layers.dropout(x,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))
        else:
            x = tf.layers.dropout(input_embedding,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))

        # Just to keep consistent with BERT; this actually makes no difference.
        if self._hparams.use_bert_config:
            pad_remover = None
        else:
            pad_remover = utils.transformer_utils.PadRemover(inputs_padding)

        for i in range(self._hparams.num_blocks):
            with tf.variable_scope("layer_{}".format(i)):
                multihead_attention = self.multihead_attention_list[i]
                # trivial difference between BERT and original Transformer
                if self._hparams.use_bert_config:
                    _queries_input = x
                else:
                    _queries_input = layers.layer_normalize(x)

                attention_output = multihead_attention(
                    queries=_queries_input,
                    memory=_queries_input,
                    memory_attention_bias=encoder_self_attention_bias,
                    mode=mode,
                )
                attention_output = tf.layers.dropout(
                    attention_output,
                    rate=self._hparams.residual_dropout,
                    training=is_train_mode(mode),
                )
                x = x + attention_output
                with tf.variable_scope('output'):
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)
                        y = x
                    else:
                        y = layers.layer_normalize(x)
                poswise_network = self.poswise_networks[i]
                with tf.variable_scope(poswise_network.variable_scope):
                    original_shape = shape_list(y)
                    y = tf.reshape(y, [-1, self._hparams.dim])
                    if pad_remover:
                        y = tf.expand_dims(pad_remover.remove(y), axis=0)
                        # [1, batch_size*seq_length, hidden_dim]
                    layer_output = poswise_network(y, mode=mode)
                    sub_output = tf.layers.dropout(
                        layer_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode)
                    )
                    if pad_remover:
                        sub_output = tf.reshape(
                            pad_remover.restore(tf.squeeze(sub_output, axis=0)),
                            original_shape)
                    else:
                        sub_output = tf.reshape(sub_output, original_shape)

                    x = x + sub_output
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)

        if not self._hparams.use_bert_config:
            x = layers.layer_normalize(x)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return x
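The `PadRemover` used above compacts a flattened `[batch*time, dim]` tensor by dropping rows at padded positions, so the position-wise network only processes real tokens; `restore` re-inserts zero rows afterwards. A sketch of the contract this encoder assumes (illustrative, not the utility's source):

pad_remover = utils.transformer_utils.PadRemover(inputs_padding)
y = tf.reshape(y, [-1, dim])                    # [batch*time, dim]
y = pad_remover.remove(y)                       # keep non-padding rows only
y = poswise_network(tf.expand_dims(y, axis=0))  # [1, num_nonpad, dim]
y = pad_remover.restore(tf.squeeze(y, axis=0))  # zeros back at padded rows
y = tf.reshape(y, original_shape)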
Example #10
    def _build(self, queries, memory, memory_attention_bias,
               cache=None, mode=None):
        """Encodes the inputs.

        Args:
            queries: A 3d tensor with shape of [batch, length_query,
                depth_query].
            memory: A 3d tensor with shape of [batch, length_key, depth_key].
            memory_attention_bias: A 3d tensor with shape of
                [batch, length_key, num_units].
            cache: Memory cache only when inferencing the sentence from sractch.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL` and `PREDICT`. Controls dropout mode.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, length_query, output_dim]`
            containing the attention outputs.
        """

        #pylint: disable=too-many-locals
        with tf.variable_scope(self.variable_scope):
            num_heads = self._hparams.num_heads
            num_units = self._hparams.num_units
            if num_units % num_heads:
                raise ValueError("Value depth (%d) must be divisible by "
                                 "the number of attention heads (%d)." %(\
                                 num_units, num_heads))
            if memory is None:
                # Self Attention
                Q = self.Q_dense(queries)
                K = self.K_dense(queries)
                V = self.V_dense(queries)

                if cache is not None:
                    # 'decoder self attention when dynamic decoding'
                    K = tf.concat([cache['self_keys'], K], axis=1)
                    V = tf.concat([cache['self_values'], V], axis=1)
                    cache['self_keys'] = K
                    cache['self_values'] = V
            else:
                # encoder decoder attention
                Q = self.Q_dense(queries)
                if cache is not None:
                    K, V = tf.cond(
                        tf.equal(tf.shape(cache["memory_keys"])[1], 0),
                        true_fn=lambda: \
                            [self.K_dense(memory), self.V_dense(memory)],
                        false_fn=lambda: \
                            [cache["memory_keys"], cache["memory_values"]])
                else:
                    K, V = [self.K_dense(memory), self.V_dense(memory)]

            Q_ = self._split_heads(Q)
            K_ = self._split_heads(K)
            V_ = self._split_heads(V)
            #[batch_size, num_heads, seq_length, memory_depth]
            key_depth_per_head = num_units // num_heads
            Q_ *= key_depth_per_head**-0.5

            logits = tf.matmul(Q_, K_, transpose_b=True)
            if memory_attention_bias is not None:
                logits += memory_attention_bias
            weights = tf.nn.softmax(logits, name="attention_weights")
            weights = tf.layers.dropout(weights,
                                        rate=self._hparams.dropout_rate,
                                        training=is_train_mode(mode))
            outputs = tf.matmul(weights, V_)

            outputs = self._combine_heads(outputs)
            outputs = self.O_dense(outputs)
            #(batch_size, length_query, output_dim)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return outputs
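The cache branches above imply key/value tensors that start with length 0 along the time axis and grow by concatenation as decoding proceeds. A sketch of the per-layer cache layout this module expects (inferred from the keys accessed above; `batch_size` and `num_units` are assumed, illustrative only):

cache = {
    'memory_attention_bias': memory_attention_bias,
    'layer_0': {
        'self_keys': tf.zeros([batch_size, 0, num_units]),
        'self_values': tf.zeros([batch_size, 0, num_units]),
        'memory_keys': tf.zeros([batch_size, 0, num_units]),
        'memory_values': tf.zeros([batch_size, 0, num_units]),
    },
    # ... one such dict per decoder block ('layer_1', 'layer_2', ...)
}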
Example #11
    def _build(self,
               decoding_strategy="train_greedy",
               initial_state=None,
               inputs=None,
               sequence_length=None,
               embedding=None,
               start_tokens=None,
               end_token=None,
               softmax_temperature=None,
               max_decoding_length=None,
               impute_finished=False,
               output_time_major=False,
               input_time_major=False,
               helper=None,
               mode=None,
               **kwargs):
        """Performs decoding. This is a shared interface for both
        :class:`~texar.modules.BasicRNNDecoder` and
        :class:`~texar.modules.AttentionRNNDecoder`.

        The function provides **3 ways** to specify the
        decoding method, with varying flexibility:

        1. The :attr:`decoding_strategy` argument: A string taking value of:

            - **"train_greedy"**: decoding in teacher-forcing fashion \
              (i.e., feeding \
              `ground truth` to decode the next step), and each sample is \
              obtained by taking the `argmax` of the RNN output logits. \
              Arguments :attr:`(inputs, sequence_length, input_time_major)` \
              are required for this strategy, and argument :attr:`embedding` \
              is optional.
            - **"infer_greedy"**: decoding in inference fashion (i.e., feeding \
              the `generated` sample to decode the next step), and each sample\
              is obtained by taking the `argmax` of the RNN output logits.\
              Arguments :attr:`(embedding, start_tokens, end_token)` are \
              required for this strategy, and argument \
              :attr:`max_decoding_length` is optional.
            - **"infer_sample"**: decoding in inference fashion, and each
              sample is obtained by `random sampling` from the RNN output
              distribution. Arguments \
              :attr:`(embedding, start_tokens, end_token)` are \
              required for this strategy, and argument \
              :attr:`max_decoding_length` is optional.

          This argument is used only when argument :attr:`helper` is `None`.

          Example:

            .. code-block:: python

                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

                # Teacher-forcing decoding
                outputs_1, _, _ = decoder(
                    decoding_strategy='train_greedy',
                    inputs=embedder(data_batch['text_ids']),
                    sequence_length=data_batch['length']-1)

                # Random sample decoding. Gets 100 sequence samples
                outputs_2, _, sequence_length = decoder(
                    decoding_strategy='infer_sample',
                    start_tokens=[data.vocab.bos_token_id]*100,
                    end_token=data.vocab.eos_token_id,
                    embedding=embedder,
                    max_decoding_length=60)

        2. The :attr:`helper` argument: An instance of a subclass of \
           :class:`texar.modules.Helper`. This provides a superset of the
           decoding strategies above, for example:

            - :class:`~texar.modules.TrainingHelper` corresponding to the \
              "train_greedy" strategy.
            - :class:`~texar.modules.GreedyEmbeddingHelper` and \
              :class:`~texar.modules.SampleEmbeddingHelper` corresponding to \
              the "infer_greedy" and "infer_sample", respectively.
            - :class:`~texar.modules.TopKSampleEmbeddingHelper` for Top-K \
              sample decoding.
            - :class:`ScheduledEmbeddingTrainingHelper` and \
              :class:`ScheduledOutputTrainingHelper` for scheduled \
              sampling.
            - :class:`~texar.modules.SoftmaxEmbeddingHelper` and \
              :class:`~texar.modules.GumbelSoftmaxEmbeddingHelper` for \
              soft decoding and gradient backpropagation.

          Helpers give the maximal flexibility in configuring the decoding\
          strategy.

          Example:

            .. code-block:: python

                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

                # Teacher-forcing decoding, same as above with
                # `decoding_strategy='train_greedy'`
                helper_1 = texar.modules.TrainingHelper(
                    inputs=embedder(data_batch['text_ids']),
                    sequence_length=data_batch['length']-1)
                outputs_1, _, _ = decoder(helper=helper_1)

                # Gumbel-softmax decoding
                helper_2 = GumbelSoftmaxEmbeddingHelper(
                    embedding=embedder,
                    start_tokens=[data.vocab.bos_token_id]*100,
                    end_token=data.vocab.eos_token_id,
                    tau=0.1)
                outputs_2, _, sequence_length = decoder(
                    max_decoding_length=60, helper=helper_2)

        3. :attr:`hparams["helper_train"]` and :attr:`hparams["helper_infer"]`:\
           Specifying the helper through hyperparameters. The train and infer \
           strategies are toggled based on :attr:`mode`. Appropriate arguments \
           (e.g., :attr:`inputs`, :attr:`start_tokens`, etc.) are selected to \
           construct the helper. Additional arguments for helper constructor \
           can be provided either through :attr:`**kwargs`, or through \
           :attr:`hparams["helper_train/infer"]["kwargs"]`.

           This means of specifying the helper is used only when both \
           :attr:`decoding_strategy` and :attr:`helper` are `None`.

           Example:

             .. code-block:: python

                h = {
                    "helper_infer": {
                        "type": "GumbelSoftmaxEmbeddingHelper",
                        "kwargs": { "tau": 0.1 }
                    }
                }
                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size, hparams=h)

                # Gumbel-softmax decoding
                output, _, _ = decoder(
                    decoding_strategy=None, # Set to None explicitly
                    embedding=embedder,
                    start_tokens=[data.vocab.bos_token_id]*100,
                    end_token=data.vocab.eos_token_id,
                    max_decoding_length=60,
                    mode=tf.estimator.ModeKeys.PREDICT)
                        # PREDICT mode also shuts down dropout

        Args:
            decoding_strategy (str): A string specifying the decoding
                strategy. Different arguments are required based on the
                strategy.
                Ignored if :attr:`helper` is given.
            initial_state (optional): Initial state of decoding.
                If `None` (default), zero state is used.

            inputs (optional): Input tensors for teacher forcing decoding.
                Used when `decoding_strategy` is set to "train_greedy", or
                when `hparams`-configured helper is used.

                - If :attr:`embedding` is `None`, `inputs` is directly \
                fed to the decoder. E.g., in `"train_greedy"` strategy, \
                `inputs` must be a 3D Tensor of shape \
                `[batch_size, max_time, emb_dim]` (or \
                `[max_time, batch_size, emb_dim]` if `input_time_major`==True).
                - If `embedding` is given, `inputs` is used as index \
                to look up embeddings and feed in the decoder. \
                E.g., if `embedding` is an instance of \
                :class:`~texar.modules.WordEmbedder`, \
                then :attr:`inputs` is usually a 2D int Tensor \
                `[batch_size, max_time]` (or \
                `[max_time, batch_size]` if `input_time_major`==True) \
                containing the token indexes.
            sequence_length (optional): A 1D int Tensor containing the
                sequence length of :attr:`inputs`.
                Used when `decoding_strategy="train_greedy"` or
                `hparams`-configured helper is used.
            embedding (optional): Embedding used when:

                - "infer_greedy" or "infer_sample" `decoding_strategy` is \
                used. This can be a callable or the `params` argument for \
                :tf_main:`embedding_lookup <nn/embedding_lookup>`. \
                If a callable, it can take a vector tensor of token `ids`, \
                or take two arguments (`ids`, `times`), where `ids` \
                is a vector tensor of token ids, and `times` is a vector tensor\
                of time steps (i.e., position ids). The latter case can be used\
                when :attr:`embedding` is a combination of word embedding and\
                position embedding. `embedding` is required in this case.
                - "train_greedy" `decoding_strategy` is used.\
                This can be a callable or the `params` argument for \
                :tf_main:`embedding_lookup <nn/embedding_lookup>`. \
                If a callable, it can take :attr:`inputs` and returns \
                the input embedding. `embedding` is optional in this case.
            start_tokens (optional): An int Tensor of shape `[batch_size]`,
                the start tokens. Used when `decoding_strategy="infer_greedy"`
                or `"infer_sample"`, or when the helper specified in `hparams`
                is used.

                Example:

                    .. code-block:: python

                        data = tx.data.MonoTextData(hparams)
                        iterator = DataIterator(data)
                        batch = iterator.get_next()

                        bos_token_id = data.vocab.bos_token_id
                        start_tokens=tf.ones_like(batch['length'])*bos_token_id

            end_token (optional): An int 0D Tensor, the token that marks the
                end of decoding.
                Used when `decoding_strategy="infer_greedy"` or
                `"infer_sample"`, or when the helper specified in `hparams`
                is used.
            softmax_temperature (optional): A float 0D Tensor, value to divide
                the logits by before computing the softmax. Larger values
                (above 1.0) result in more random samples. Must be > 0. If
                `None`, 1.0 is used.
                1.0 is used.
                Used when `decoding_strategy="infer_sample"`.
            max_decoding_length: An int scalar Tensor indicating the maximum
                allowed number of decoding steps. If `None` (default), either
                `hparams["max_decoding_length_train"]` or
                `hparams["max_decoding_length_infer"]` is used
                according to :attr:`mode`.
            impute_finished (bool): If `True`, then states for batch
                entries which are marked as finished get copied through and
                the corresponding outputs get zeroed out.  This causes some
                slowdown at each time step, but ensures that the final state
                and outputs have the correct values and that backprop ignores
                time steps that were marked as finished.
            output_time_major (bool): If `True`, outputs are returned as
                time major tensors. If `False` (default), outputs are returned
                as batch major tensors.
            input_time_major (optional): Whether the :attr:`inputs` tensor is
                time major.
                Used when `decoding_strategy="train_greedy"` or
                `hparams`-configured helper is used.
            helper (optional): An instance of
                :class:`texar.modules.Helper`
                that defines the decoding strategy. If given,
                `decoding_strategy`
                and helper configs in :attr:`hparams` are ignored.
            mode (str, optional): A string taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`. If
                `TRAIN`, training related hyperparameters are used (e.g.,
                `hparams['max_decoding_length_train']`), otherwise,
                inference related hyperparameters are used (e.g.,
                `hparams['max_decoding_length_infer']`).
                If `None` (default), `TRAIN` mode is used.
            **kwargs: Other keyword arguments for constructing helpers
                defined by `hparams["helper_train"]` or
                `hparams["helper_infer"]`.

        Returns:
            `(outputs, final_state, sequence_lengths)`, where

            - **`outputs`**: an object containing the decoder output on all \
            time steps.
            - **`final_state`**: the cell state of the final time step.
            - **`sequence_lengths`**: an int Tensor of shape `[batch_size]` \
            containing the length of each sample.
        """
        # Helper
        if helper is not None:
            pass
        elif decoding_strategy is not None:
            if decoding_strategy == "train_greedy":
                helper = rnn_decoder_helpers._get_training_helper(
                    inputs, sequence_length, embedding, input_time_major)
            elif decoding_strategy == "infer_greedy":
                helper = tx_helper.GreedyEmbeddingHelper(
                    embedding, start_tokens, end_token)
            elif decoding_strategy == "infer_sample":
                helper = tx_helper.SampleEmbeddingHelper(
                    embedding, start_tokens, end_token, softmax_temperature)
            else:
                raise ValueError(
                    "Unknown decoding strategy: {}".format(decoding_strategy))
        else:
            if is_train_mode_py(mode):
                kwargs_ = copy.copy(self._hparams.helper_train.kwargs.todict())
                helper_type = self._hparams.helper_train.type
            else:
                kwargs_ = copy.copy(self._hparams.helper_infer.kwargs.todict())
                helper_type = self._hparams.helper_infer.type
            kwargs_.update({
                "inputs": inputs,
                "sequence_length": sequence_length,
                "time_major": input_time_major,
                "embedding": embedding,
                "start_tokens": start_tokens,
                "end_token": end_token,
                "softmax_temperature": softmax_temperature
            })
            kwargs_.update(kwargs)
            helper = rnn_decoder_helpers.get_helper(helper_type, **kwargs_)
        self._helper = helper

        # Initial state
        if initial_state is not None:
            self._initial_state = initial_state
        else:
            self._initial_state = self.zero_state(batch_size=self.batch_size,
                                                  dtype=tf.float32)

        # Maximum decoding length
        max_l = max_decoding_length
        if max_l is None:
            max_l_train = self._hparams.max_decoding_length_train
            if max_l_train is None:
                max_l_train = utils.MAX_SEQ_LENGTH
            max_l_infer = self._hparams.max_decoding_length_infer
            if max_l_infer is None:
                max_l_infer = utils.MAX_SEQ_LENGTH
            max_l = tf.cond(is_train_mode(mode), lambda: max_l_train,
                            lambda: max_l_infer)
        self.max_decoding_length = max_l
        # Decode
        outputs, final_state, sequence_lengths = dynamic_decode(
            decoder=self,
            impute_finished=impute_finished,
            maximum_iterations=max_l,
            output_time_major=output_time_major)

        if not self._built:
            self._add_internal_trainable_variables()
            # Add trainable variables of `self._cell` which may be
            # constructed externally.
            self._add_trainable_variable(
                layers.get_rnn_cell_trainable_variables(self._cell))
            if isinstance(self._output_layer, tf.layers.Layer):
                self._add_trainable_variable(
                    self._output_layer.trainable_variables)
            # Add trainable variables of `self._beam_search_rnn_cell` which
            # may already be constructed and used.
            if self._beam_search_cell is not None:
                self._add_trainable_variable(
                    self._beam_search_cell.trainable_variables)

            self._built = True

        return outputs, final_state, sequence_lengths
Example #12
    def _build(self, inputs, sequence_length, mode=None):
        """Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`,
                containing the word embeddings of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`.
            sequence_length: A 1D Tensor of shape `[batch_size]`. Input tokens
                beyond respective sequence lengths are masked out
                automatically.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, max_time, dim]` containing the
            encoded vectors.
        """
        # Multiply input embedding with the sqrt of its dimension for
        # normalization
        inputs = inputs * self._hparams.dim**0.5

        inputs = mask_sequences(inputs, sequence_length, tensor_rank=3)

        _, lengths, _ = shape_list(inputs)

        inputs_padding = 1 - tf.sequence_mask(
            sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
        ignore_padding = attn.attention_bias_ignore_padding(inputs_padding)
        encoder_self_attention_bias = ignore_padding

        pos_embeds = self.position_embedder(lengths, self._hparams.dim)
        input_embedding = inputs + pos_embeds

        x = tf.layers.dropout(input_embedding,
                              rate=self._hparams.embedding_dropout,
                              training=is_train_mode(mode))
        pad_remover = utils.transformer_utils.PadRemover(inputs_padding)

        for i in range(self._hparams.num_blocks):
            with tf.variable_scope("layer_{}".format(i)):
                with tf.variable_scope('self_attention'):
                    selfatt_output = attn.multihead_attention(
                        queries=layers.layer_normalize(x),
                        memory=None,
                        memory_attention_bias=encoder_self_attention_bias,
                        num_heads=self._hparams.num_heads,
                        dropout_rate=self._hparams.attention_dropout,
                        num_units=self._hparams.dim,
                        scope='multihead_attention')
                    x = x + tf.layers.dropout(
                        selfatt_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode),
                    )

                poswise_network = FeedForwardNetwork(
                    hparams=self._hparams['poswise_feedforward'])
                with tf.variable_scope(poswise_network.variable_scope):
                    y = layers.layer_normalize(x)
                    original_shape = shape_list(y)
                    y = tf.reshape(y, [-1, self._hparams.dim])
                    y = tf.expand_dims(pad_remover.remove(y), axis=0)
                    # [1, batch_size*seq_length, hidden_dim]
                    sub_output = tf.layers.dropout(
                        poswise_network(y),
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode))
                    sub_output = tf.reshape(
                        pad_remover.restore(tf.squeeze(sub_output, axis=0)),
                        original_shape)
                    x = x + sub_output

        encoder_output = layers.layer_normalize(x)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return encoder_output
Example #13
    def _build(self,
               queries,
               memory,
               memory_attention_bias,
               cache=None,
               mode=None):
        """Encodes the inputs.

        Args:
            queries: A 3d tensor with shape of [batch, length_query,
                depth_query].
            memory: A 3d tensor with shape of [batch, length_key, depth_key].
            memory_attention_bias: A 3d tensor with shape of
                [batch, length_key, num_units].
            cache: Memory cache only when inferencing the sentence from sractch.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL` and `PREDICT`. Controls dropout mode.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            A Tensor of shape `[batch_size, length_query, output_dim]`
            containing the attention outputs.
        """

        with tf.variable_scope(self.variable_scope):
            num_heads = self._hparams.num_heads
            num_units = self._hparams.num_units
            if num_units % num_heads:
                raise ValueError("Value depth (%d) must be divisible by "
                                 "the number of attention heads (%d)."
                                 % (num_units, num_heads))

            def _update_and_return(layer, key):
                if memory is None:
                    # Self Attention
                    out = layer(queries)

                    if cache is not None:
                        # 'decoder self attention when dynamic decoding'
                        key = 'self_{}'.format(key)
                        res = cache[key]
                        if isinstance(res, tf.TensorArray):
                            # inference-like decoding
                            # TODO(zhiting): This writing op may cause a bug
                            # on CPU--it looks the two TensorArray
                            # cache['self_keys'] and cache['self_values']
                            # will mix up starting from certain step, causing
                            # shape mismatch. This op looks fine on GPU.
                            res = res.write(res.size(),
                                            tf.squeeze(out, axis=[1]))
                            out = transpose_batch_time(res.stack())
                        else:
                            # normal decoding
                            res = tf.concat([res, out], axis=1)
                            out = res
                        cache[key] = res

                else:
                    # encoder decoder attention
                    if cache is not None:
                        key = 'memory_{}'.format(key)
                        res = cache[key]
                        if isinstance(res, tf.TensorArray):
                            # inference-like decoding
                            size = res.size()
                            false_fn = lambda: transpose_batch_time(res.stack())
                        else:
                            # normal decoding
                            size = tf.shape(res)[1]
                            false_fn = lambda: res
                        out = tf.cond(tf.equal(size, 0),
                                      true_fn=lambda: layer(memory),
                                      false_fn=false_fn)
                    else:
                        out = layer(memory)

                return out

            Q = self.Q_dense(queries)
            K = _update_and_return(self.K_dense, 'keys')
            V = _update_and_return(self.V_dense, 'values')

            Q_ = self._split_heads(Q)
            K_ = self._split_heads(K)
            V_ = self._split_heads(V)
            #[batch_size, num_heads, seq_length, memory_depth]
            key_depth_per_head = num_units // num_heads
            Q_ *= key_depth_per_head**-0.5

            logits = tf.matmul(Q_, K_, transpose_b=True)
            if memory_attention_bias is not None:
                logits += memory_attention_bias
            weights = tf.nn.softmax(logits, name="attention_weights")
            weights = tf.layers.dropout(weights,
                                        rate=self._hparams.dropout_rate,
                                        training=is_train_mode(mode))
            outputs = tf.matmul(weights, V_)

            outputs = self._combine_heads(outputs)
            outputs = self.O_dense(outputs)
            #(batch_size, length_query, output_dim)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return outputs
Example #14
    def _build(self,
               ids=None,
               soft_ids=None,
               stop_gradient=False,
               mode=None,
               **kwargs):
        """Embeds (soft) ids.

        Either :attr:`ids` or :attr:`soft_ids` must be given, and they
        must not be given at the same time.

        Args:
            ids (optional): An integer tensor containing the ids to embed.
            soft_ids (optional): A tensor of weights (probabilities) used to
                mix the embedding vectors.
            stop_gradient (bool): Whether to stop gradient back-propagation
                to the embedding tensor. This can be used when, e.g., updating
                only `soft_ids` while keeping the embedding tensor fixed.
                Defaults to `False`.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout is
                controlled by :func:`texar.global_mode`.
            kwargs: Additional keyword arguments for
                :tf_main:`tf.nn.embedding_lookup <nn/embedding_lookup>` besides
                :attr:`params` and :attr:`ids`.

        Returns:
            If :attr:`ids` is given, returns a Tensor of shape
            `shape(ids) + embedding-dim`. For example,
            if `shape(ids) = [batch_size, max_time]`
            and `shape(embedding) = [vocab_size, emb_dim]`, then the return
            tensor has shape `[batch_size, max_time, emb_dim]`.

            If :attr:`soft_ids` is given, returns a Tensor of shape
            `shape(soft_ids)[:-1] + embedding-dim`. For example,
            if `shape(soft_ids) = [batch_size, max_time, vocab_size]`
            and `shape(embedding) = [vocab_size, emb_dim]`, then the return
            tensor has shape `[batch_size, max_time, emb_dim]`.
        """
        if ids is not None:
            if soft_ids is not None:
                raise ValueError(
                    'Must not specify `ids` and `soft_ids` at the same time.')
            ids_rank = get_rank(ids)
        elif soft_ids is not None:
            ids_rank = get_rank(soft_ids) - 1
        else:
            raise ValueError('Either `ids` or `soft_ids` must be given.')

        embedding = self._embedding
        if stop_gradient:
            embedding = tf.stop_gradient(embedding)

        is_training = is_train_mode(mode)
        if self._hparams.dropout_strategy == 'item_type':
            dropout_layer = self._get_dropout_layer(self._hparams)
            if dropout_layer:
                embedding = dropout_layer.apply(inputs=embedding,
                                                training=is_training)

        if ids is not None:
            outputs = tf.nn.embedding_lookup(embedding, ids, **kwargs)
        else:
            outputs = embedder_utils.soft_embedding_lookup(embedding, soft_ids)

        if self._hparams.dropout_strategy != 'item_type':
            dropout_layer = self._get_dropout_layer(self._hparams,
                                                    ids_rank=ids_rank,
                                                    dropout_input=outputs)
            if dropout_layer:
                outputs = dropout_layer.apply(inputs=outputs,
                                              training=is_training)

        return outputs
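A hypothetical usage sketch of the embedder above (a `WordEmbedder`-style module whose `__call__` routes to `_build`; constructor arguments are illustrative):

embedder = WordEmbedder(vocab_size=vocab_size, hparams={'dim': 100})
# Hard lookup: one embedding vector per integer id.
word_embeds = embedder(ids=data_batch['text_ids'])
# Soft lookup: mixes embedding vectors with probability weights, e.g. to
# backpropagate through a Gumbel-softmax sample while keeping the
# embedding table fixed.
soft_embeds = embedder(soft_ids=tf.nn.softmax(logits), stop_gradient=True)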
Example #15
    def _self_attention_stack(self,
                              inputs,
                              memory,
                              decoder_self_attention_bias=None,
                              memory_attention_bias=None,
                              cache=None,
                              mode=None):
        """Stacked multihead attention module.
        """
        def _layer_norm(x, scope):
            return layers.layer_normalize(x, reuse=tf.AUTO_REUSE, scope=scope)

        inputs = tf.layers.dropout(inputs,
                                   rate=self._hparams.embedding_dropout,
                                   training=is_train_mode(mode))
        if cache is not None:
            if memory is not None:
                memory_attention_bias = \
                    cache['memory_attention_bias']
        else:
            assert decoder_self_attention_bias is not None

        # self.adj_masks is set at the beginning of _build()
        adj_masks = self.adj_masks

        x = inputs
        for i in range(self._hparams.num_blocks):
            layer_name = 'layer_{}'.format(i)
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name) as layer_scope:
                with tf.variable_scope("self_attention"):
                    graph_multihead_attention = \
                        self.multihead_attentions['self_att'][i]
                    selfatt_output = graph_multihead_attention(
                        queries=_layer_norm(x, layer_scope),
                        memory=None,
                        adj_masks=adj_masks,
                        memory_attention_bias=decoder_self_attention_bias,
                        cache=layer_cache,
                        mode=mode,
                    )
                    x = x + tf.layers.dropout(
                        selfatt_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode),
                    )
                if memory is not None:
                    with tf.variable_scope('encdec_attention') as \
                            encdec_attention_scope:
                        graph_multihead_attention = \
                            self.multihead_attentions['encdec_att'][i]
                        encdec_output = graph_multihead_attention(
                            queries=_layer_norm(x, encdec_attention_scope),
                            memory=memory,
                            adj_masks=adj_masks,
                            memory_attention_bias=memory_attention_bias,
                            mode=mode,
                        )
                        x = x + tf.layers.dropout(
                            encdec_output,
                            rate=self._hparams.residual_dropout,
                            training=is_train_mode(mode))
                poswise_network = self.poswise_networks[i]
                with tf.variable_scope('past_poswise_ln') as \
                        past_poswise_ln_scope:
                    sub_output = tf.layers.dropout(
                        poswise_network(_layer_norm(x, past_poswise_ln_scope)),
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode),
                    )
                    x = x + sub_output

        return _layer_norm(x, scope=self.variable_scope)
Example #16
    def _build(self,
               inputs,
               memory,
               sequence_length,
               memory_sequence_length,
               adjs,
               encoder_output,
               mode=None):
        """Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`,
                containing the embedding of input sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`. The input embedding is typically an aggregation
                of word embedding and position embedding.
            memory: A 3D Tensor of shape `[batch_size, memory_max_time, dim]`, 
                containing the embedding of memory sequences. Note that
                the embedding dimension `dim` must equal "dim" in
                :attr:`hparams`. The input embedding is typically an aggregation
                of word embedding and position embedding.
            sequence_length: A 1D Tensor of shape `[batch_size]`. Input tokens
                beyond respective sequence lengths are masked out
                automatically.
            memory_sequence_length: A 1D Tensor of shape `[batch_size]`.
                Memory tokens beyond respective sequence lengths are masked
                out automatically.
            adjs: A 3D Tensor of shape `[batch_size, max_time, max_time]`,
                containing the adjacency matrices of the input sequences.
            encoder_output: A bool. If `True`, returns the encoder-like
                embeddings; if `False`, returns a
                :class:`CrossGraphTransformerFixedLengthDecoderOutput`.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.global_mode` is used.

        Returns:
            If :attr:`encoder_output` is `True`, a Tensor of shape
            `[batch_size, max_time, dim]` containing the encoded vectors;
            otherwise, a :class:`CrossGraphTransformerFixedLengthDecoderOutput`
            containing the logits and sampled ids.
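
        Example:
            A minimal, hypothetical call (argument names are illustrative
            only)::

                output = model(
                    inputs=word_embeds + pos_embeds,
                    memory=memory_embeds,
                    sequence_length=seq_lengths,
                    memory_sequence_length=memory_lengths,
                    adjs=adjacency_matrices,
                    encoder_output=True,
                    mode=tf.estimator.ModeKeys.TRAIN)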
        """
        # Get adjacency masks from adjs
        adj_masks = 1 - tf.cast(tf.equal(adjs, 0), dtype=tf.float32)
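        # E.g., an adjacency row [0., 2., 1.] yields the mask [0., 1., 1.]:
        # any nonzero edge weight maps to 1, absent edges stay 0.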

        # Compute the padding mask: 1 at padded positions, 0 at valid tokens
        inputs_padding = 1 - tf.sequence_mask(
            sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
        if self._hparams.use_bert_config:
            ignore_padding = attn.attention_bias_ignore_padding(
                inputs_padding, bias_value=-1e4)
        else:
            ignore_padding = attn.attention_bias_ignore_padding(inputs_padding)
        encoder_self_attention_bias = ignore_padding
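        # The bias is a large negative value at padded positions (-1e4 under
        # the BERT config above) so the attention softmax assigns them
        # near-zero weight.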

        input_embedding = inputs  # shape (batch_size, max_time, dim)

        if self._hparams.use_bert_config:
            x = layers.layer_normalize(input_embedding)
            x = tf.layers.dropout(x,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))
        else:
            x = tf.layers.dropout(input_embedding,
                                  rate=self._hparams.embedding_dropout,
                                  training=is_train_mode(mode))

        # For consistency with BERT, padding removal is disabled under the
        # BERT config; this does not change the computed results.
        if self._hparams.use_bert_config:
            pad_remover = None
        else:
            pad_remover = utils.transformer_utils.PadRemover(inputs_padding)
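            # PadRemover strips padded positions before the position-wise
            # feed-forward layers and restores them afterwards (see
            # `remove`/`restore` below), saving computation on padding.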

        for i in range(self._hparams.num_blocks):
            with tf.variable_scope("layer_{}".format(i)):
                graph_multihead_attention = \
                    self.graph_multihead_attention_list[i]

                # Minor difference between BERT and the original Transformer:
                # BERT post-normalizes, so queries are not layer-normalized
                # here
                if self._hparams.use_bert_config:
                    _queries_input = x
                else:
                    _queries_input = layers.layer_normalize(x)

                attention_output = graph_multihead_attention(
                    queries=_queries_input,
                    memory=memory,
                    adj_masks=adj_masks,
                    memory_attention_bias=encoder_self_attention_bias,
                    mode=mode,
                )
                attention_output = tf.layers.dropout(
                    attention_output,
                    rate=self._hparams.residual_dropout,
                    training=is_train_mode(mode),
                )
                # attention_output: weighted sum of the values (V) of
                # `memory`, with weights obtained by matching the queries
                # against the keys of `memory`
                x = x + attention_output
                with tf.variable_scope('output'):
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)
                        y = x
                    else:
                        y = layers.layer_normalize(x)

                poswise_network = self.poswise_networks[i]
                with tf.variable_scope(poswise_network.variable_scope):
                    original_shape = shape_list(y)
                    y = tf.reshape(y, [-1, self._hparams.dim])
                    if pad_remover:
                        y = tf.expand_dims(pad_remover.remove(y), axis=0)
                        # [1, batch_size*seq_length, hidden_dim]
                    layer_output = poswise_network(y, mode=mode)
                    sub_output = tf.layers.dropout(
                        layer_output,
                        rate=self._hparams.residual_dropout,
                        training=is_train_mode(mode))
                    if pad_remover:
                        sub_output = tf.reshape(
                            pad_remover.restore(tf.squeeze(sub_output, axis=0)),
                            original_shape)
                    else:
                        sub_output = tf.reshape(sub_output, original_shape)

                    x = x + sub_output
                    if self._hparams.use_bert_config:
                        x = layers.layer_normalize(x)

        if not self._hparams.use_bert_config:
            x = layers.layer_normalize(x)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        if encoder_output:
            return x

        logits = self._output_layer(x)
        sample_ids = tf.to_int32(tf.argmax(logits, axis=-1))
        probs = ''  # placeholder; probabilities are not computed here
        # Alternatives (currently disabled):
        # probs = GumbelSoftmax(self._tau, logits=logits).sample()
        # probs = tf.nn.softmax(logits / self._tau)  # vanilla softmax

        rets = CrossGraphTransformerFixedLengthDecoderOutput(
            logits=logits, sample_id=sample_ids, probs=probs)

        return rets

    def _build(self,
               decoding_strategy="train_greedy",
               initial_state=None,
               inputs=None,
               memory=None,
               sequence_length=None,
               embedding=None,
               start_tokens=None,
               end_token=None,
               softmax_temperature=None,
               max_decoding_length=None,
               impute_finished=False,
               output_time_major=False,
               input_time_major=False,
               helper=None,
               mode=None,
               **kwargs):
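        """Performs decoding.

        Builds or selects a helper according to :attr:`decoding_strategy`
        (or the ``helper_train``/``helper_infer`` hyperparameters when
        neither :attr:`helper` nor a strategy is given), initializes the
        attention memory and the initial state, then runs
        :func:`dynamic_decode`.

        Example:
            A minimal, hypothetical greedy-decoding call (names are
            illustrative only)::

                outputs, final_state, lengths = decoder(
                    decoding_strategy="infer_greedy",
                    memory=encoder_outputs,
                    embedding=embedder.embedding,
                    start_tokens=tf.fill([batch_size], bos_id),
                    end_token=eos_id,
                    max_decoding_length=50)

        Returns:
            A tuple `(outputs, final_state, sequence_lengths)`.
        """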
        # Memory: initialize each attention mechanism with the given memory
        for _mechanism in self._cell._attention_mechanisms:
            _mechanism.initialize_memory(memory)
        # Helper: an explicitly provided helper takes precedence; otherwise
        # one is built from `decoding_strategy` or from the hparams
        if helper is not None:
            pass
        elif decoding_strategy is not None:
            if decoding_strategy == "train_greedy":
                helper = rnn_decoder_helpers._get_training_helper(
                    inputs, sequence_length, embedding, input_time_major)
            elif decoding_strategy == "infer_greedy":
                helper = tx_helper.GreedyEmbeddingHelper(
                    embedding, start_tokens, end_token)
            elif decoding_strategy == "infer_sample":
                helper = tx_helper.SampleEmbeddingHelper(
                    embedding, start_tokens, end_token, softmax_temperature)
            else:
                raise ValueError(
                    "Unknown decoding strategy: {}".format(decoding_strategy))
        else:
            if is_train_mode_py(mode):
                kwargs_ = copy.copy(self._hparams.helper_train.kwargs.todict())
                helper_type = self._hparams.helper_train.type
            else:
                kwargs_ = copy.copy(self._hparams.helper_infer.kwargs.todict())
                helper_type = self._hparams.helper_infer.type
            kwargs_.update({
                "inputs": inputs,
                "sequence_length": sequence_length,
                "time_major": input_time_major,
                "embedding": embedding,
                "start_tokens": start_tokens,
                "end_token": end_token,
                "softmax_temperature": softmax_temperature
            })
            kwargs_.update(kwargs)
            helper = rnn_decoder_helpers.get_helper(helper_type, **kwargs_)
        self._helper = helper

        # Initial state
        if initial_state is not None:
            self._initial_state = initial_state
        else:
            self._initial_state = self.zero_state(batch_size=self.batch_size,
                                                  dtype=tf.float32)

        # Maximum decoding length
        max_l = max_decoding_length
        if max_l is None:
            max_l_train = self._hparams.max_decoding_length_train
            if max_l_train is None:
                max_l_train = utils.MAX_SEQ_LENGTH
            max_l_infer = self._hparams.max_decoding_length_infer
            if max_l_infer is None:
                max_l_infer = utils.MAX_SEQ_LENGTH
            max_l = tf.cond(is_train_mode(mode), lambda: max_l_train,
                            lambda: max_l_infer)
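            # `max_l` is a scalar Tensor here; the train/infer value is
            # selected at session run time according to the mode.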
        self.max_decoding_length = max_l
        # Decode
        outputs, final_state, sequence_lengths = dynamic_decode(
            decoder=self,
            impute_finished=impute_finished,
            maximum_iterations=max_l,
            output_time_major=output_time_major)

        if not self._built:
            self._add_internal_trainable_variables()
            # Add trainable variables of `self._cell` which may be
            # constructed externally.
            self._add_trainable_variable(
                layers.get_rnn_cell_trainable_variables(self._cell))
            if isinstance(self._output_layer, tf.layers.Layer):
                self._add_trainable_variable(
                    self._output_layer.trainable_variables)
            # Add trainable variables of `self._beam_search_rnn_cell` which
            # may already be constructed and used.
            if self._beam_search_cell is not None:
                self._add_trainable_variable(
                    self._beam_search_cell.trainable_variables)

            self._built = True

        return outputs, final_state, sequence_lengths