Example 1
    def call(self, inputs, inputs_padding, is_training=True):
        """ Encodes the inputs.

        Args:
            inputs: The embedded input, a float tensor with shape
                [batch_size, max_length, embedding_dim].
            inputs_padding: A float tensor with shape [batch_size, max_length],
                indicating the padding positions, with 1.0 for padding and
                0.0 for non-padding.
            is_training: A bool, whether in training mode or not.

        Returns:
            The encoded output with shape [batch_size, max_length, hidden_size]
        """
        # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
        all_layers = []
        self_attention_bias = layer_utils.input_padding_to_bias(inputs_padding)
        if self.get_config()["attention_monotonic"]:
            self_attention_bias = tf.minimum(
                tf.expand_dims(tf.expand_dims(self_attention_bias, axis=1),
                               axis=1),
                layer_utils.lower_triangle_attention_bias(tf.shape(inputs)[1]))
        x = inputs
        if is_training:
            x = tf.nn.dropout(
                x, rate=self.get_config()["layer_postprocess_dropout_rate"])
        for layer in self._stacking_layers:
            x = layer(x, self_attention_bias, is_training=is_training)
            all_layers.append(x)
        if self.get_config()["post_normalize"]:
            if self._return_all_layers:
                return all_layers
            return x
        outputs = self.quant(self._output_norm_layer(x), name="output_ln")
        return outputs
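
To make the bias handling above concrete, below is a minimal, self-contained sketch of how a 1.0/0.0 padding mask becomes an additive attention bias and how it is merged with a lower-triangle (monotonic) mask. The helpers are illustrative stand-ins, assuming `layer_utils.input_padding_to_bias` maps padding positions to a large negative constant (written here as a local FLOAT_MIN); they are not the library's implementation.

import tensorflow as tf

FLOAT_MIN = -1e9  # assumed stand-in for the library's FLOAT_MIN constant

def padding_to_bias(inputs_padding):
    # 1.0 (padding) -> FLOAT_MIN, 0.0 (real token) -> 0.0
    return inputs_padding * FLOAT_MIN

def lower_triangle_bias(length):
    # 0.0 on and below the diagonal, FLOAT_MIN above it (future positions)
    valid = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
    return (1.0 - valid) * FLOAT_MIN

padding = tf.constant([[0., 0., 1.]])  # one sequence whose last position is padding
bias = padding_to_bias(padding)        # [batch_size, max_length]
# Broadcast to [batch_size, 1, 1, max_length] and merge with the monotonic mask,
# mirroring the tf.minimum(...) call in the encoder above.
combined = tf.minimum(bias[:, tf.newaxis, tf.newaxis, :], lower_triangle_bias(3))
print(combined.shape)  # (1, 1, 3, 3)

Because the bias is added to the attention logits before the softmax, positions carrying FLOAT_MIN receive effectively zero attention weight.
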
Example 2
    def call(self, inputs, inputs_padding, is_training=True):
        """ Encodes the inputs.

        Args:
            inputs: The embedded input, a float tensor with shape
                [batch_size, max_length, embedding_dim].
            inputs_padding: A float tensor with shape [batch_size, max_length],
                indicating the padding positions, with 1.0 for padding and
                0.0 for non-padding.
            is_training: A bool, whether in training mode or not.

        Returns:
            The encoded output with shape [batch_size, max_length, hidden_size]
        """
        # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
        self_attention_bias = layer_utils.input_padding_to_bias(inputs_padding)
        x = inputs
        if is_training:
            x = tf.nn.dropout(
                x, rate=self.get_config()["layer_postprocess_dropout_rate"])
        for idx, layer in enumerate(self._stacking_layers):
            self_attention_layer, ffn_layer = layer
            with tf.name_scope("layer_{}".format(idx)):
                # self attention layer
                x = self_attention_layer(
                    x,  # x as query
                    bias=self_attention_bias,
                    is_training=is_training)
                # ffn
                x = ffn_layer(x, is_training=is_training)

        return self._output_norm_layer(x)
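
The following toy sketch mimics the loop structure above with plain Keras layers standing in for the real sublayers; the Dense layers are placeholders chosen for illustration, since the actual self-attention and FFN layers also take `bias` and `is_training` arguments.

import tensorflow as tf

hidden_size = 16
# Placeholder sublayers; the real encoder stacks (self-attention, FFN) pairs.
stacking_layers = [(tf.keras.layers.Dense(hidden_size),
                    tf.keras.layers.Dense(hidden_size)) for _ in range(2)]
output_norm_layer = tf.keras.layers.LayerNormalization()

x = tf.random.normal([2, 5, hidden_size])  # [batch_size, max_length, hidden_size]
for idx, (self_attention_layer, ffn_layer) in enumerate(stacking_layers):
    with tf.name_scope("layer_{}".format(idx)):
        x = self_attention_layer(x)  # real code: (x, bias=..., is_training=...)
        x = ffn_layer(x)             # real code: (x, is_training=...)
print(output_norm_layer(x).shape)    # (2, 5, 16)
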
Example 3
    def create_decoding_internal_cache(self,
                                       encoder_outputs,
                                       encoder_inputs_padding,
                                       is_inference=False,
                                       decode_padded_length=None):
        """ Creates internal cache for decoding.

        Args:
            encoder_outputs: The output tensor from the encoder,
                with shape [batch_size, max_input_length, hidden_size].
            encoder_inputs_padding: A float tensor with shape [batch_size, max_length],
                indicating the padding positions of `encoder_outputs`, with 1.0 for
                padding and 0.0 for non-padding.
            is_inference: A boolean scalar, whether in inference mode or not.
            decode_padded_length: The maximum decoding length during inference, used
                for creating a static-shape cache.

        Returns:
            `cache`, a dictionary containing static (e.g. encoder hidden states
            for attention) and dynamic (e.g. transformer decoding cache) tensors
            used during decoding, which will be passed to `call()`. Note that the
            dynamic tensors must be stored in cache["decoding_states"] for beam
            search use.
        """
        # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
        if is_inference:
            params = self.get_config()
            decoding_states = {}
            batch_size = tf.shape(encoder_outputs)[0]
            num_heads = params["num_attention_heads"]
            num_units_per_head = params["hidden_size"] // num_heads
            # initialize decoder self attention keys/values
            for lid in range(params["num_layers"]):
                # Ensure shape invariance for tf.while_loop.
                decoding_states["layer_{}".format(lid)] = {
                    "self_attention": {
                        "keys":
                        tf.zeros([
                            batch_size, decode_padded_length or 0, num_heads,
                            num_units_per_head
                        ],
                                 dtype=compat.CUSTOM_GLOBAL_FLOATX),
                        "values":
                        tf.zeros([
                            batch_size, decode_padded_length or 0, num_heads,
                            num_units_per_head
                        ],
                                 dtype=compat.CUSTOM_GLOBAL_FLOATX)
                    },
                }
        else:
            decoding_states = None
        cache = dict(decoding_states=decoding_states)
        if self._with_encoder_decoder_attention:
            cache["memory"] = encoder_outputs
            cache["memory_bias"] = layer_utils.input_padding_to_bias(
                encoder_inputs_padding)
        return cache
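
The zero-length keys/values tensors above only fix the cache structure (for `tf.while_loop` shape invariance); during incremental decoding, each step's projected key/value is typically appended along the time axis. A minimal sketch of that growth pattern follows; the concat-per-step update is an assumption about the consuming attention layer, which is not shown here.

import tensorflow as tf

batch_size, num_heads, num_units_per_head = 2, 4, 8

# Empty per-layer cache, as built above: the time dimension starts at 0
# (or at decode_padded_length when a static-shape cache is requested).
layer_cache = {
    "keys": tf.zeros([batch_size, 0, num_heads, num_units_per_head]),
    "values": tf.zeros([batch_size, 0, num_heads, num_units_per_head]),
}

# One decoding step: append the newly projected key/value along axis 1.
new_key = tf.random.normal([batch_size, 1, num_heads, num_units_per_head])
new_value = tf.random.normal([batch_size, 1, num_heads, num_units_per_head])
layer_cache["keys"] = tf.concat([layer_cache["keys"], new_key], axis=1)
layer_cache["values"] = tf.concat([layer_cache["values"], new_value], axis=1)
print(layer_cache["keys"].shape)  # (2, 1, 4, 8)
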
Example 4
    def create_decoding_internal_cache(self,
                                       encoder_outputs,
                                       encoder_inputs_padding,
                                       is_inference=False,
                                       decode_padded_length=None):
        """ Creates internal cache for decoding.

        Args:
            encoder_outputs: The output tensor from the encoder,
                with shape [batch_size, max_input_length, hidden_size].
            encoder_inputs_padding: A float tensor with shape [batch_size, max_length],
                indicating the padding positions of `encoder_outputs`, with 1.0 for
                padding and 0.0 for non-padding.
            is_inference: A boolean scalar, whether in inference mode or not.
            decode_padded_length: The maximum decoding length during inference, used
                for creating a static-shape cache.

        Returns:
            `cache`, a dictionary containing static (e.g. encoder hidden states
            for attention) and dynamic (e.g. transformer decoding cache) tensors
            used during decoding, which will be passed to `call()`. Note that the
            dynamic tensors must be stored in cache["decoding_states"] for beam
            search use.
        """
        # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
        enc_dec_attention_bias = layer_utils.input_padding_to_bias(
            encoder_inputs_padding)
        if is_inference:
            params = self.get_config()
            decoding_states = {}
            batch_size = tf.shape(encoder_outputs)[0]
            # initialize decoder conv hidden states
            for lid in range(params["num_layers"]):
                # Ensure shape invariance for tf.while_loop.
                if decode_padded_length is None:
                    init_len = params["conv_kernel_size_list"][lid] - 1
                else:
                    init_len = (params["conv_kernel_size_list"][lid] - 1
                                + decode_padded_length)
                decoding_states["layer_{}".format(lid)] = {
                    "light_conv": {
                        "conv":
                        tf.zeros(
                            [batch_size, init_len, params["conv_hidden_size"]],
                            dtype=compat.CUSTOM_GLOBAL_FLOATX)
                    },
                }
        else:
            decoding_states = None
        cache = dict(decoding_states=decoding_states,
                     memory=encoder_outputs,
                     memory_bias=enc_dec_attention_bias)
        return cache
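
In this variant the per-layer cache holds conv_kernel_size - 1 zero frames so that the first decoding step already sees a full convolution window. Below is a small sketch of how such a cache could be rolled forward step by step; the sliding-window update is an assumption about the lightweight-convolution layer, not code from the library.

import tensorflow as tf

batch_size, conv_hidden_size, kernel_size = 2, 8, 3

# Initial cache, as built above: kernel_size - 1 zero frames.
conv_cache = tf.zeros([batch_size, kernel_size - 1, conv_hidden_size])

# One decoding step: the causal convolution sees the cached frames plus the
# new frame, and the cache carried to the next step keeps only the last
# kernel_size - 1 frames.
new_frame = tf.random.normal([batch_size, 1, conv_hidden_size])
window = tf.concat([conv_cache, new_frame], axis=1)
conv_cache = window[:, -(kernel_size - 1):]
print(window.shape, conv_cache.shape)  # (2, 3, 8) (2, 2, 8)
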
Example 5
    def create_decoding_internal_cache(self,
                                       encoder_outputs,
                                       encoder_inputs_padding,
                                       is_inference=False,
                                       decode_padded_length=None):
        """ Creates internal cache for decoding.

        Args:
            encoder_outputs: The output tensor from the encoder,
                with shape [batch_size, max_input_length, hidden_size].
            encoder_inputs_padding: A float tensor with shape [batch_size, max_length],
                indicating the padding positions of `encoder_outputs`, with 1.0 for
                padding and 0.0 for non-padding.
            is_inference: A boolean scalar, whether in inference mode or not.
            decode_padded_length: The maximum decoding length during inference, used
                for creating a static-shape cache.

        Returns:
            `cache`, a dictionary containing static (e.g. encoder hidden states
            for attention) and dynamic (e.g. transformer decoding cache) tensors
            used during decoding, which will be passed to `call()`. Note that the
            dynamic tensors must be stored in cache["decoding_states"] for beam
            search use.
        """
        # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
        if is_inference:
            decoding_states = {}
            batch_size = tf.shape(encoder_outputs)[0]
            # initialize per-layer decoding caches (e.g. self-attention keys/values)
            for lid, layer in enumerate(self._stacking_layers):
                # Ensure shape invariance for tf.while_loop.
                decoding_states[
                    f"layer_{lid}"] = layer.create_decoding_internal_cache(
                        decode_padded_length)
            decoding_states = tf.nest.map_structure(
                lambda ts: tile_tensor(ts, batch_size, axis=0),
                decoding_states)
            for lid, layer in enumerate(self._stacking_layers):
                decoding_states[f"layer_{lid}"].update(
                    layer.memorize_memory(encoder_outputs))
        else:
            decoding_states = None
        cache = dict(decoding_states=decoding_states)
        if encoder_inputs_padding is not None:
            cache["memory"] = encoder_outputs
            cache["memory_bias"] = layer_utils.input_padding_to_bias(
                encoder_inputs_padding)
        return cache
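
Here each layer builds its own cache entry and the whole nested structure is then tiled to the batch size with `tf.nest.map_structure`. The sketch below shows that tiling step with a hypothetical `tile_tensor`; its behavior is inferred from how it is used above, not taken from the library.

import tensorflow as tf

def tile_tensor(tensor, size, axis=0):
    # Hypothetical stand-in: repeat a cache entry `size` times along `axis`,
    # e.g. [1, 0, 4, 8] -> [size, 0, 4, 8].
    multiples = [1] * len(tensor.shape)
    multiples[axis] = size
    return tf.tile(tensor, multiples)

per_layer_cache = {"self_attention": {"keys": tf.zeros([1, 0, 4, 8]),
                                      "values": tf.zeros([1, 0, 4, 8])}}
batch_size = 3
tiled = tf.nest.map_structure(lambda ts: tile_tensor(ts, batch_size, axis=0),
                              per_layer_cache)
print(tiled["self_attention"]["keys"].shape)  # (3, 0, 4, 8)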