Example #1
def get_ops(self, L):
    # Build, and cache per sequence length L, the three block-sparse kernels
    # used by attention: Q @ K^T (sdd), probs @ V (dsd), and a sparse softmax.
    if L not in MultiheadAttention.ops:
        sparse_dot_sdd_nt = torch_blocksparse.MatMul(self.layout, self.block, 'sdd', trans_a=False, trans_b=True)
        sparse_dot_dsd_nn = torch_blocksparse.MatMul(self.layout, self.block, 'dsd', trans_a=False, trans_b=False)
        sparse_softmax = torch_blocksparse.Softmax(self.layout, self.block)
        MultiheadAttention.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax)
    return MultiheadAttention.ops[L]
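
How these cached operators fit together: a minimal sketch, assuming dense q, k, v tensors of shape (batch, heads, seq_len, head_dim); the method name, tensor names, and scaling below are illustrative, not part of the original snippet.

def sparse_attention(self, q, k, v):
    # sketch only: argument conventions of the real module may differ
    sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops(q.size(-2))
    scores = sparse_dot_sdd_nt(q, k)                 # block-sparse q @ k^T (trans_b=True)
    probs = sparse_softmax(scores, scale=q.size(-1) ** -0.5)
    return sparse_dot_dsd_nn(probs, v)               # dense output = probs @ v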
Example #2

def run_softmax_triton(x, scale, dx, kp_mask, attn_mask, layout, block):
    sparse_softmax = torch_blocksparse.Softmax(layout, block, bench=False)
    # pack the dense input and upstream gradient into the block-sparse
    # format described by `layout`
    dx = dense_to_sparse(dx, layout, block)
    x = dense_to_sparse(x, layout, block)
    x.retain_grad()
    y = sparse_softmax(x,
                       scale=scale,
                       key_padding_mask=kp_mask,
                       key_padding_mask_mode='add',
                       attn_mask=attn_mask,
                       attn_mask_mode='mul')
    # back-propagate the upstream gradient and capture the input gradient
    y.backward(dx)
    dx = x.grad.clone()
    x.grad.zero_()
    return x, dx
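
The snippet above relies on a dense_to_sparse helper that is not shown. A plausible sketch, assuming a dense input of shape (Z, H, M, N) and a layout of shape (H, M // block, N // block); this is an illustration, not the library's own implementation.

import torch

def dense_to_sparse(w, layout, block):
    # pack the blocks marked nonzero in `layout` into the
    # (Z, num_nonzero_blocks, block, block) format used by the sparse kernels
    blocks = [w[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
              for h, i, j in layout.nonzero(as_tuple=False).tolist()]
    return torch.stack(blocks, dim=1)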
Example #3
def get_ops(self, L):
    # Build, and cache per sequence length L, the sparsity layout and the
    # block-sparse matmul/softmax kernels for this configuration.
    if L not in self.__class__.ops:
        sparsity = self.sparsity
        layout = self.__class__._make_layout(
            self.num_attention_heads_per_partition, L // sparsity.block,
            sparsity.mode, sparsity.stride // sparsity.block,
            sparsity.unidirectional, sparsity.numverts, sparsity.vertsize)
        sparse_dot_sdd_nt = torch_blocksparse.MatMul(layout,
                                                     sparsity.block,
                                                     'sdd',
                                                     trans_a=False,
                                                     trans_b=True)
        sparse_dot_dsd_nn = torch_blocksparse.MatMul(layout,
                                                     sparsity.block,
                                                     'dsd',
                                                     trans_a=False,
                                                     trans_b=False)
        sparse_softmax = torch_blocksparse.Softmax(layout, sparsity.block)
        self.__class__.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn,
                                 sparse_softmax)
    return self.__class__.ops[L]
Example #4
def get_ops(self, H, L):
    # Build, and cache per sequence length L, the block-sparse kernels for
    # H attention heads: Q @ K^T (sdd), probs @ V (dsd), and a sparse softmax.
    if L not in DeepSpeedSparseSelfAttention.ops:
        spConfig = self.sparsity_config

        num_blocks = L // spConfig.block
        if num_blocks != L / spConfig.block:
            raise ValueError(
                f'Sequence length {L} must be divisible by block size {spConfig.block}'
            )

        block_stride = spConfig.stride // spConfig.block
        if block_stride != spConfig.stride / spConfig.block:
            raise ValueError(
                f'Stride {spConfig.stride} must be divisible by block size {spConfig.block}'
            )

        layout = DeepSpeedSparseSelfAttention._make_layout(
            H, num_blocks, spConfig.mode, block_stride, spConfig.attention,
            spConfig.numverts, spConfig.vertsize)

        sparse_dot_sdd_nt = torch_blocksparse.MatMul(layout,
                                                     spConfig.block,
                                                     'sdd',
                                                     trans_a=False,
                                                     trans_b=True)

        sparse_dot_dsd_nn = torch_blocksparse.MatMul(layout,
                                                     spConfig.block,
                                                     'dsd',
                                                     trans_a=False,
                                                     trans_b=False)

        sparse_softmax = torch_blocksparse.Softmax(layout, spConfig.block)

        DeepSpeedSparseSelfAttention.ops[L] = (sparse_dot_sdd_nt,
                                               sparse_dot_dsd_nn,
                                               sparse_softmax)
    return DeepSpeedSparseSelfAttention.ops[L]
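
The two divisibility checks above can be written more directly with the modulo operator; an equivalent sketch (not the original code):

if L % spConfig.block != 0:
    raise ValueError(f'Sequence length {L} must be divisible by block size {spConfig.block}')
if spConfig.stride % spConfig.block != 0:
    raise ValueError(f'Stride {spConfig.stride} must be divisible by block size {spConfig.block}')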
Example #5
def get_ops(self, L):
    # Build, and cache per sequence length L, the layout and block-sparse
    # kernels used by this MultiheadAttention module.
    if L not in MultiheadAttention.ops:
        sparsity = self.sparsity
        layout = MultiheadAttention._make_layout(
            self.num_heads, L // sparsity.block, sparsity.mode,
            sparsity.stride // sparsity.block, sparsity.unidirectional,
            sparsity.numverts, sparsity.vertsize)
        sparse_dot_sdd_nt = torch_blocksparse.MatMul(layout,
                                                     sparsity.block,
                                                     'sdd',
                                                     trans_a=False,
                                                     trans_b=True)
        sparse_dot_dsd_nn = torch_blocksparse.MatMul(layout,
                                                     sparsity.block,
                                                     'dsd',
                                                     trans_a=False,
                                                     trans_b=False)
        sparse_softmax = torch_blocksparse.Softmax(layout, sparsity.block)
        MultiheadAttention.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn,
                                     sparse_softmax)
    return MultiheadAttention.ops[L]
Example #6

def run_bench_softmax(Z,
                      H,
                      M,
                      N,
                      scale,
                      rho,
                      block,
                      dtype,
                      layout=None,
                      repeat=10):
    layout, x, dx, _, attn_mask, kp_mask = init_inputs(Z,
                                                       H,
                                                       M,
                                                       N,
                                                       scale,
                                                       rho,
                                                       block,
                                                       dtype,
                                                       dense_x=False,
                                                       layout=layout)
    x = x.clone()
    dx = dx.clone()
    # forward function
    sparse_softmax = torch_blocksparse.Softmax(layout, block, bench=False)
    y = sparse_softmax(x, scale, None, None, 'add', 'mul')
    # backward function: call the softmax autograd node directly
    backward = y.grad_fn.apply
    backward(dx)
    x = x.clone()
    # benchmark
    time_y = bench(lambda: sparse_softmax(x, scale, None, None, 'add', 'mul'),
                   repeat)
    time_dx = bench(lambda: backward(dx), repeat)
    # approximate memory traffic (GB) per forward and per backward call
    gb_y = (2 * nbytes(x) + nbytes(attn_mask) + nbytes(kp_mask)) * 1e-9
    gb_dx = 3 * nbytes(x) * 1e-9
    return time_y, time_dx, gb_y, gb_dx
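
run_bench_softmax also depends on init_inputs, bench, and nbytes helpers that are not shown. Hypothetical sketches of the two small ones, under the assumption that bench returns average seconds per call and nbytes returns a tensor's size in bytes:

import time
import torch

def nbytes(t):
    # tensor payload size in bytes; treat a missing mask as zero bytes (assumption)
    return 0 if t is None else t.numel() * t.element_size()

def bench(fn, repeat):
    # crude average wall-clock time per call, synchronizing the GPU around the loop
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) / repeat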
Example #7
import torch
import torch_blocksparse

# Z: non-sparse batch dimension
# H: sparse batch dimension
# M: row dimension
# N: column dimension
# K: inner (reduction) dimension
Z, H, M, N, K = 4, 2, 256, 512, 384
# note: with trans_a=True below, `a` is laid out as (Z, H, K, M) so that
# trans(a) @ b yields an M x N result
a = torch.rand((Z, H, K, M), dtype=torch.float32).cuda()
b = torch.rand((Z, H, K, N), dtype=torch.float32).cuda()
# create sparsity layout
block = 16
layout = torch.randint(0, 2, (H, M // block, N // block))
# create object for Sparse = trans(Dense) x Dense (sdd);
# this has some overhead, as it pre-computes the look-up tables
# needed internally by the GPU kernels
dot = torch_blocksparse.MatMul(layout,
                               block,
                               'sdd',
                               trans_a=True,
                               trans_b=False)
c = dot(a, b)
# create object for Sparse = softmax(Sparse)
softmax = torch_blocksparse.Softmax(layout, block)
d = softmax(c)
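
A possible continuation (not part of the original example): feed the sparse softmax output into a 'dsd' (Dense = Sparse x Dense) product, mirroring the probs @ values step of attention. The dense tensor v below is an assumption for illustration.

v = torch.rand((Z, H, N, K), dtype=torch.float32).cuda()
# create object for Dense = Sparse x Dense (dsd)
dot_dsd = torch_blocksparse.MatMul(layout, block, 'dsd', trans_a=False, trans_b=False)
e = dot_dsd(d, v)  # dense result of shape (Z, H, M, K)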