# Example #1 (score: 0)
    def forward(self,
                query,
                key,
                value,
                attn_mask=None,
                use_cache=False,
                cache=None):
        r"""
        Multi-head attention: map queries and a set of key-value pairs to
        outputs.

        NOTE(review): ``attn_mask`` is accepted for interface compatibility
        but is not applied explicitly here — causal masking is performed by
        the fused upper-triangle softmax kernel below.
        """
        # Self-attention when key/value are omitted.
        if key is None:
            key = query
        if value is None:
            value = query

        # Project inputs to per-head q, k, v. Keep the exact ``is False``
        # test from the original: only a literal False skips the
        # cache-returning path.
        if use_cache is False:
            if self.fuse:
                q, k, v = self._fuse_prepare_qkv(query)
            else:
                q, k, v = self._prepare_qkv(query, key, value, use_cache,
                                            cache)
        else:
            q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
                                               cache)

        # Scaled dot-product scores: q @ k^T scaled by 1/sqrt(head_dim).
        scores = layers.matmul(x=q,
                               y=k,
                               transpose_y=True,
                               alpha=self.head_dim**-0.5)

        # Fused causal (upper-triangle) mask + softmax in one kernel.
        attn_probs = incubate.softmax_mask_fuse_upper_triangle(scores)

        if self.dropout:
            # Draw dropout randomness from the tracked 'local_seed' state so
            # results stay reproducible under the RNG state tracker.
            with get_rng_state_tracker().rng_state('local_seed'):
                attn_probs = F.dropout(attn_probs,
                                       self.dropout,
                                       training=self.training,
                                       mode="upscale_in_train")

        ctx = tensor.matmul(attn_probs, v)

        # Combine heads: [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D].
        ctx = tensor.transpose(ctx, perm=[0, 2, 1, 3])
        ctx = tensor.reshape(x=ctx, shape=[0, 0, ctx.shape[2] * ctx.shape[3]])

        # Final linear projection to the output dimension.
        out = self.out_proj(ctx)

        results = (out,)
        if self.need_weights:
            results = results + (attn_probs,)
        if use_cache:
            results = results + (cache,)
        return results[0] if len(results) == 1 else results
# Example #2 (score: 0)
    def test_dygraph(self):
        """Compare the fused op with the numpy reference in dygraph mode."""
        for dtype in self.dtypes:
            with fluid.dygraph.guard(fluid.CUDAPlace(0)):
                np_input = np.random.random((1, 4, 32, 32)).astype(dtype)
                # Reference result; float16 inputs use the fp16 path.
                expected = _get_softmax_upper(np_input, dtype == 'float16')

                var_input = fluid.dygraph.to_variable(np_input)
                actual = incubate.softmax_mask_fuse_upper_triangle(var_input)

                self.assertTrue(np.allclose(actual, expected))
# Example #3 (score: 0)
    def test_static(self):
        """Compare the fused op with the numpy reference in static-graph mode."""
        for dtype in self.dtypes:
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                x_var = fluid.data(name="x",
                                   shape=[1, 4, 32, 32],
                                   dtype=dtype)
                fused = incubate.softmax_mask_fuse_upper_triangle(x_var)

                np_input = np.random.random((1, 4, 32, 32)).astype(dtype)
                # Reference result; float16 inputs use the fp16 path.
                expected = _get_softmax_upper(np_input, dtype == 'float16')

                exe = fluid.Executor(fluid.CUDAPlace(0))
                outputs = exe.run(fluid.default_main_program(),
                                  feed={"x": np_input},
                                  fetch_list=[fused])
                self.assertTrue(np.allclose(outputs[0], expected))