Example No. 1
    def call(self, inputs, inputs_padding, is_training=True):
        """ Encodes the inputs.

        Args:
            inputs: The embedded input, a float tensor with shape
                [batch_size, max_length, embedding_dim].
            inputs_padding: A float tensor with shape [batch_size, max_length],
                indicating the padding positions, where 1.0 for padding and
                0.0 for non-padding.
            is_training: A bool, whether in training mode or not.

        Returns:
            The encoded output with shape [batch_size, max_length, hidden_size]
        """
        all_layers = []
        # [batch_size, max_length], FLOAT_MIN for padding, 0.0 for non-padding
        self_attention_bias = layer_utils.input_padding_to_bias(inputs_padding)
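        # monotonic attention: additionally apply a lower-triangle (causal) mask
        # so that position i only attends to non-padding positions <= i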
        if self.get_config()["attention_monotonic"]:
            self_attention_bias = tf.minimum(
                tf.expand_dims(tf.expand_dims(self_attention_bias, axis=1),
                               axis=1),
                layer_utils.lower_triangle_attention_bias(tf.shape(inputs)[1]))
        x = inputs
        if is_training:
            x = tf.nn.dropout(
                x, rate=self.get_config()["layer_postprocess_dropout_rate"])
        for idx, layer in enumerate(self._stacking_layers):
            x = layer(x, self_attention_bias, is_training=is_training)
            all_layers.append(x)
        if self.get_config()["post_normalize"]:
            if self._return_all_layers:
                return all_layers
            return x
        outputs = self.quant(self._output_norm_layer(x), name="output_ln")
        return outputs
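The comment above relies on `layer_utils.input_padding_to_bias` converting the 0/1 padding mask into an additive attention bias. A minimal sketch of what such a helper might compute is shown below; the FLOAT_MIN constant and the body are assumptions for illustration and may differ from the actual layer_utils code.

import tensorflow as tf

FLOAT_MIN = -1e9  # assumed large negative value used to mask attention logits

def toy_input_padding_to_bias(inputs_padding):
    """Sketch: map a [batch_size, max_length] mask (1.0 = padding, 0.0 = token)
    to an additive bias with FLOAT_MIN at padding positions, 0.0 elsewhere."""
    return tf.cast(inputs_padding, tf.float32) * FLOAT_MIN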
Example No. 2
    def call(self, decoder_inputs, cache, is_training=True, decode_loop_step=None):
        """ Encodes the inputs.

        Args:
            decoder_inputs: The embedded decoder input, a float tensor with shape
                [batch_size, max_target_length, embedding_dim] or
                [batch_size, embedding_dim] for one decoding step.
            cache: A dictionary, generated from self.create_decoding_internal_cache.
            is_training: A bool, whether in training mode or not.
            decode_loop_step: An integer, step number of the decoding loop. Used only
                for autoregressive inference with static-shape cache.

        Returns:
            The decoder output with shape [batch_size, max_target_length, hidden_size]
            when `decoder_inputs` is a 3-d tensor or with shape
            [batch_size, hidden_size] when `decoder_inputs` is a 2-d tensor.
        """
        ori_ndims = decoder_inputs.get_shape().ndims
        if ori_ndims == 2:
            decoder_inputs = tf.expand_dims(decoder_inputs, axis=1)

        # decoder self-attention bias has shape [1, 1, max_target_len, max_target_len]
        decoder_self_attention_bias = layer_utils.lower_triangle_attention_bias(
            tf.shape(decoder_inputs)[1])
        x = decoder_inputs
        if is_training:
            x = tf.nn.dropout(
                x, rate=self.get_config()["layer_postprocess_dropout_rate"])
        for idx, layer in enumerate(self._stacking_layers):
            selfatt_layer = layer[0]
            encdecatt_layer = layer[1]
            ffn_layer = layer[2]
            layer_name = "layer_{}".format(idx)
            layer_cache = (None if cache["decoding_states"] is None else
                           cache["decoding_states"][layer_name])
            selfatt_cache = None if layer_cache is None else layer_cache["self_attention"]
            with tf.name_scope(layer_name):
                # self attention layer
                x = selfatt_layer(
                    x,  # x as query
                    bias=decoder_self_attention_bias,
                    cache=selfatt_cache,
                    is_training=is_training,
                    decode_loop_step=decode_loop_step)
                # enc-dec attention layer
                if encdecatt_layer is not None:
                    x = encdecatt_layer(
                        x,  # x as query
                        memory=cache["memory"],  # None indicates self-attention
                        memory_bias=cache["memory_bias"],
                        is_training=is_training)
                # ffn
                x = ffn_layer(x, is_training=is_training)
        outputs = x
        if not self.get_config()["post_normalize"]:
            outputs = self.quant(self._output_norm_layer(x), name="output_ln")
        if ori_ndims == 2:
            outputs = tf.squeeze(outputs, axis=1)
        return outputs
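The `cache` dictionary is documented only as coming from `create_decoding_internal_cache`. From the lookups above it must expose "decoding_states" (per-layer self-attention state) plus "memory" and "memory_bias". The sketch below builds a toy dictionary with that outer layout; the inner key names and the zero-length key/value buffers are assumptions, not the actual neurst structure.

import tensorflow as tf

def toy_decoding_cache(batch_size, num_layers, num_heads, dim_per_head,
                       memory=None, memory_bias=None):
    """Hypothetical cache layout compatible with the decoder call above."""
    decoding_states = {}
    for lid in range(num_layers):
        decoding_states["layer_{}".format(lid)] = {
            "self_attention": {
                # per-layer key/value buffers that grow (or are written at
                # decode_loop_step when using a static-shape cache) during decoding
                "keys": tf.zeros([batch_size, 0, num_heads, dim_per_head]),
                "values": tf.zeros([batch_size, 0, num_heads, dim_per_head]),
            }
        }
    return {"decoding_states": decoding_states,
            "memory": memory,            # encoder output; None for decoder-only use
            "memory_bias": memory_bias}  # encoder padding bias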
Example No. 3
    def incremental_encode(self, inputs, cache, time=None):
        """ Encoding function for streaming input.

        Args:
            inputs: The embedded input at time t, a float tensor with shape
                [batch, embedding_dim] or [batch, length, embedding_dim].
            cache: A dict containing cached tensors.
            time: The start time of the inputs.

        Returns: The incremental encoder output with shape [batch, t+1, dim],
            and the updated cache dict.
        """
        params = self.get_config()
        assert params["attention_monotonic"], (
            "function `incremental_encode` only available when attention_monotonic=True"
        )
        if cache is None:
            cache = {}
        if len(cache) == 0:
            batch_size = tf.shape(inputs)[0]
            for lid in range(params["num_layers"]):
                cache[f"layer_{lid}"] = self._stacking_layers[
                    lid].create_internal_cache()
            cache = tf.nest.map_structure(
                lambda ts: layer_utils.tile_tensor(ts, batch_size, axis=0),
                cache)
        if inputs.get_shape().ndims == 2:
            x = tf.expand_dims(inputs, axis=1)
            x_bias = None
        else:
            x = inputs
            if time is None:
                time = 0
            x_bias = layer_utils.lower_triangle_attention_bias(
                time + tf.shape(x)[1])[:, :, -tf.shape(x)[1]:]
        for idx, layer in enumerate(self._stacking_layers):
            layer_cache = None if cache is None else cache[f"layer_{idx}"]
            x = layer(x, x_bias, layer_cache, is_training=False)
        outputs = x
        if not params["post_normalize"]:
            outputs = self.quant(self._output_norm_layer(x), name="output_ln")
        return outputs, cache
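A possible driver loop for `incremental_encode`, feeding one embedded frame per step and threading the returned cache through successive calls; `encoder` is assumed to be an instance of this encoder class and is passed in as a parameter.

def stream_encode(encoder, embedded_stream):
    """Sketch: encode a [batch, num_frames, dim] stream one frame at a time.

    Returns the per-step encoder outputs and the final cache.
    """
    cache = None                  # the first call creates and tiles per-layer caches
    outputs = []
    num_frames = int(embedded_stream.shape[1])
    for t in range(num_frames):
        frame = embedded_stream[:, t, :]   # [batch, dim], a single time step
        enc_out, cache = encoder.incremental_encode(frame, cache, time=t)
        outputs.append(enc_out)
    return outputs, cache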
Example No. 4
    def call(self,
             decoder_inputs,
             cache,
             decode_lagging=None,
             is_training=True,
             decode_loop_step=None):
        """ Encodes the inputs.

        Args:
            decoder_inputs: The embedded decoder input, a float tensor with shape
                [batch_size, max_target_length, embedding_dim] or
                [batch_size, embedding_dim] for one decoding step.
            cache: A dictionary, generated from self.create_decoding_internal_cache.
            decode_lagging: The wait-k lagging for streaming input during training.
                During inference, it is the lagging for the current step.
            is_training: A bool, whether in training mode or not.
            decode_loop_step: An integer, step number of the decoding loop. Used only
                for autoregressive inference with static-shape cache.

        Returns:
            The decoder output with shape [batch_size, max_target_length, hidden_size]
            when `decoder_inputs` is a 3-d tensor or with shape
            [batch_size, hidden_size] when `decoder_inputs` is a 2-d tensor.
        """
        ori_ndims = decoder_inputs.get_shape().ndims
        if ori_ndims == 2:
            decoder_inputs = tf.expand_dims(decoder_inputs, axis=1)
        memory_bias = cache.get("memory_bias", None)  # [batch, memory_length]
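        # wait-k: additionally mask enc-dec attention so each query only sees
        # the source prefix permitted by the current lagging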
        if memory_bias is not None and decode_lagging is not None:
            if ori_ndims == 3:
                memory_bias = tf.minimum(
                    tf.expand_dims(memory_bias, axis=1),
                    tf.expand_dims(layer_utils.waitk_attention_bias(
                        memory_length=tf.shape(memory_bias)[1],
                        query_length=tf.shape(decoder_inputs)[1],
                        waitk_lagging=decode_lagging),
                                   axis=0))
            else:  # ori_ndims == 2
                memory_bias = tf.minimum(
                    memory_bias,
                    tf.expand_dims(layer_utils.waitk_attention_bias(
                        memory_length=tf.shape(memory_bias)[1],
                        waitk_lagging=decode_lagging),
                                   axis=0))
        # decoder self-attention bias has shape [1, 1, max_target_len, max_target_len]
        decoder_self_attention_bias = layer_utils.lower_triangle_attention_bias(
            tf.shape(decoder_inputs)[1])
        x = decoder_inputs
        if is_training:
            x = tf.nn.dropout(
                x, rate=self.get_config()["layer_postprocess_dropout_rate"])
        for idx, layer in enumerate(self._stacking_layers):
            layer_cache = (None if cache["decoding_states"] is None else
                           cache["decoding_states"][f"layer_{idx}"])
            x = layer(x,
                      decoder_self_attention_bias,
                      layer_cache,
                      memory=cache.get("memory", None),
                      memory_bias=memory_bias,
                      is_training=is_training,
                      decode_loop_step=decode_loop_step)
        outputs = x
        if not self.get_config()["post_normalize"]:
            outputs = self.quant(self._output_norm_layer(x), name="output_ln")
        if ori_ndims == 2:
            outputs = tf.squeeze(outputs, axis=1)
        return outputs
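The wait-k restriction comes from `layer_utils.waitk_attention_bias`, called above with either a per-query bias (training, 3-d inputs) or a single-step bias (inference, 2-d inputs). The sketch below illustrates the idea (each query step may attend only to the source prefix allowed by the lagging); the FLOAT_MIN constant and the exact semantics are assumptions, not the actual layer_utils code.

import tensorflow as tf

FLOAT_MIN = -1e9  # assumed masking constant

def toy_waitk_attention_bias(memory_length, waitk_lagging, query_length=None):
    """Sketch of a wait-k bias: query step i sees only the first
    (i + waitk_lagging) source positions; the rest get FLOAT_MIN.

    Returns [memory_length] when query_length is None (one decoding step),
    else [query_length, memory_length]."""
    src_pos = tf.range(memory_length)
    if query_length is None:
        return tf.where(src_pos < waitk_lagging, 0.0, FLOAT_MIN)
    qry_pos = tf.range(query_length)
    visible = tf.expand_dims(src_pos, 0) < (tf.expand_dims(qry_pos, 1) + waitk_lagging)
    return tf.where(visible, 0.0, FLOAT_MIN)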
Example No. 5
def test_lower_triangle_attention_bias():
    assert_equal_numpy(lower_triangle_attention_bias(5).numpy(),
                       pt_lower_triangle_attention_bias(5).detach().numpy())
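For context, the test checks a TensorFlow implementation against a PyTorch reference. A sketch of what both functions plausibly compute, a [1, 1, length, length] causal bias with 0.0 on and below the diagonal and a large negative value above it, is given below; the names, the FLOAT_MIN value, and the exact output shape are assumptions and may differ from the project's implementations.

import numpy as np
import tensorflow as tf
import torch

FLOAT_MIN = -1e9  # assumed masking constant

def toy_lower_triangle_attention_bias(length):
    """TF sketch: 0.0 on/below the diagonal, FLOAT_MIN above it, [1, 1, L, L]."""
    lower = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
    return tf.reshape((1.0 - lower) * FLOAT_MIN, [1, 1, length, length])

def toy_pt_lower_triangle_attention_bias(length):
    """PyTorch counterpart serving as the reference implementation."""
    lower = torch.tril(torch.ones(length, length))
    return ((1.0 - lower) * FLOAT_MIN).reshape(1, 1, length, length)

# The two should agree elementwise, which is what such a test asserts:
np.testing.assert_array_equal(
    toy_lower_triangle_attention_bias(5).numpy(),
    toy_pt_lower_triangle_attention_bias(5).detach().numpy())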