コード例 #1
0
    def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
        r"""
        Build the query, key and value tensors consumed by the subsequent
        multi-head attention, reusing ``cache`` when one is supplied.

        Args:
            query: Input passed through ``self.q_proj``.
            key: Forwarded to ``self.compute_kv`` unless ``cache`` is a
                ``self.StaticCache``.
            value: Forwarded to ``self.compute_kv`` together with ``key``.
            use_cache: Compared by identity. Exactly ``True`` builds a new
                ``self.Cache`` from the final keys/values; exactly ``False``
                drops the cache from the return value.
            cache: ``None``, a ``self.StaticCache`` or a ``self.Cache``.

        Returns:
            ``(q, k, v)`` when ``use_cache is False``, otherwise
            ``(q, k, v, cache)``.
        """
        # Project, reshape to [0, 0, num_heads, head_dim], then swap axes 1
        # and 2. NOTE(review): a ``0`` entry presumably keeps the matching
        # input dimension, per Paddle's reshape convention -- confirm.
        queries = tensor.transpose(
            x=tensor.reshape(
                x=self.q_proj(query),
                shape=[0, 0, self.num_heads, self.head_dim]),
            perm=[0, 2, 1, 3])

        if not isinstance(cache, self.StaticCache):
            keys, values = self.compute_kv(key, value)
        else:
            # Encoder-decoder attention at inference: k/v are already cached.
            keys, values = cache.k, cache.v

        if isinstance(cache, self.Cache):
            # Decoder self-attention at inference: extend the cached k/v
            # with the newly computed ones along axis 2.
            keys = tensor.concat([cache.k, keys], axis=2)
            values = tensor.concat([cache.v, values], axis=2)

        if use_cache is False:
            return (queries, keys, values)
        if use_cache is True:
            cache = self.Cache(keys, values)
        return (queries, keys, values, cache)
コード例 #2
0
 def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
     """
     Build the query, key and value tensors consumed by the subsequent
     multi-head attention, reusing ``cache`` when one is supplied.

     After projecting ``query`` with ``self.q_proj``, the projection weight
     is annotated through ``auto.shard_tensor`` when
     ``_global_parallel_strategy`` is one of ``"mp"``, ``"dp_mp"``,
     ``"mp_pp"`` or ``"dp_mp_pp"``; any other strategy leaves it untouched.

     Returns ``(q, k, v)`` when ``use_cache is False`` and
     ``(q, k, v, cache)`` otherwise; ``use_cache`` is compared by identity,
     and only exactly ``True`` builds a new ``self.Cache``.
     """
     projected = self.q_proj(query)

     # Select the (process mesh, dims mapping) pair for the active strategy.
     # The pipeline variants index a per-stage mesh list with self.mesh_idx.
     strategy = _global_parallel_strategy
     if strategy == "mp":
         placement = (_global_process_mesh, [-1, 0])
     elif strategy == "dp_mp":
         placement = (_global_process_mesh, [-1, 1])
     elif strategy == "mp_pp":
         placement = (MPPP_MESH_LIST[self.mesh_idx], [-1, 0])
     elif strategy == "dp_mp_pp":
         placement = (DPMPPP_MESH_LIST[self.mesh_idx], [-1, 1])
     else:
         placement = None

     if placement is not None:
         mesh, mapping = placement
         auto.shard_tensor(self.q_proj.weight,
                           dist_attr={
                               "process_mesh": mesh,
                               "dims_mapping": mapping
                           })

     # Reshape to [0, 0, num_heads, head_dim], then swap axes 1 and 2.
     queries = tensor.transpose(
         x=tensor.reshape(
             x=projected,
             shape=[0, 0, self.num_heads, self.head_dim]),
         perm=[0, 2, 1, 3])

     if not isinstance(cache, self.StaticCache):
         keys, values = self.compute_kv(key, value)
     else:
         # Encoder-decoder attention at inference: k/v are already cached.
         keys, values = cache.k, cache.v

     if isinstance(cache, self.Cache):
         # Decoder self-attention at inference: extend the cached k/v with
         # the newly computed ones along axis 2.
         keys = tensor.concat([cache.k, keys], axis=2)
         values = tensor.concat([cache.v, values], axis=2)

     if use_cache is False:
         return (queries, keys, values)
     if use_cache is True:
         cache = self.Cache(keys, values)
     return (queries, keys, values, cache)
コード例 #3
0
ファイル: __init__.py プロジェクト: zzz2010/paddorch
    def probs(self, value):
        """Probabilities of the given category (``value``).

        If ``logits`` is 2-D or higher dimension, the last dimension will be
        regarded as category, and the others represent the different
        distributions.
        At the same time, if ``value`` is a 1-D Tensor, ``value`` will be
        broadcast to the same number of distributions as ``logits``.
        If ``value`` is not a 1-D Tensor, ``value`` should have the same number
        of distributions as ``logits``. That is,
        ``value.shape[:-1] == logits.shape[:-1]``.

        Args:
            value (Tensor): The input tensor represents the selected category index.

        Returns:
            Tensor: probability according to the category index.

        Raises:
            ValueError: If ``logits`` has more than one dimension, ``value`` is
                not 1-D, and the leading dimensions of ``value`` differ from
                those of ``logits``.

        Examples:
            .. code-block:: python

                import paddle
                from paddle.distribution import Categorical
                paddle.seed(100) # on CPU device
                x = paddle.rand([6])
                print(x)
                # [0.5535528  0.20714243 0.01162981
                #  0.51577556 0.36369765 0.2609165 ]
                cat = Categorical(x)
                value = paddle.to_tensor([2,1,3])
                cat.probs(value)
                # [0.00608027 0.108298 0.269656]
        """
        name = self.name + '_probs'

        # Normalize ``logits`` by their sum over the last (category) axis.
        dist_sum = nn.reduce_sum(self.logits, dim=-1, keep_dim=True)
        prob = self.logits / dist_sum

        shape = list(prob.shape)
        value_shape = list(value.shape)
        if len(shape) == 1:
            # Single distribution: every entry of ``value`` is a direct index
            # into ``prob``.
            num_value_in_one_dist = np.prod(value_shape)
            index = nn.reshape(value, [num_value_in_one_dist, 1])
        else:
            num_dist = np.prod(shape[:-1])
            num_value_in_one_dist = value_shape[-1]
            if len(value_shape) == 1:
                # Reuse the same category indices for every distribution.
                value = nn.expand(value, [num_dist])
                value_shape = shape[:-1] + value_shape
            elif shape[:-1] != value_shape[:-1]:
                # Validate before any reshape so a mismatch surfaces as this
                # error rather than as a failure inside the reshape op.
                raise ValueError(
                    "shape of value {} must match shape of logits {}".format(
                        str(value_shape[:-1]), str(shape[:-1])))

            # Flatten the distributions to rows so each lookup is addressed
            # by a (row, category) pair.
            prob = nn.reshape(prob, [num_dist, shape[-1]])
            index_value = nn.reshape(value, [num_dist, -1, 1])

            index_prefix = nn.unsqueeze(arange(num_dist,
                                               dtype=index_value.dtype),
                                        axes=-1)
            index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
            index_prefix = nn.unsqueeze(index_prefix, axes=-1)

            if index_value.dtype != index_prefix.dtype:
                # The cast result must be kept: it was previously discarded,
                # which left the dtypes mismatched for the concat below.
                index_prefix = tensor.cast(index_prefix,
                                           dtype=index_value.dtype)
            index = concat([index_prefix, index_value], axis=-1)

        # value is the category index to search for the corresponding probability.
        select_prob = gather_nd(prob, index)
        return nn.reshape(select_prob, value_shape, name=name)