def get_ops(self, L):
    # lazily build and cache the block-sparse kernels for sequence length L
    if L not in MultiheadAttention.ops:
        sparse_dot_sdd_nt = torch_blocksparse.MatMul(self.layout, self.block, 'sdd',
                                                     trans_a=False, trans_b=True)
        sparse_dot_dsd_nn = torch_blocksparse.MatMul(self.layout, self.block, 'dsd',
                                                     trans_a=False, trans_b=False)
        sparse_softmax = torch_blocksparse.Softmax(self.layout, self.block)
        MultiheadAttention.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax)
    return MultiheadAttention.ops[L]
def run_softmax_triton(x, scale, dx, kp_mask, attn_mask, layout, block):
    sparse_softmax = torch_blocksparse.Softmax(layout, block, bench=False)
    # convert the dense inputs to the compressed block-sparse format
    dx = dense_to_sparse(dx, layout, block)
    x = dense_to_sparse(x, layout, block)
    x.retain_grad()
    y = sparse_softmax(x,
                       scale=scale,
                       key_padding_mask=kp_mask,
                       key_padding_mask_mode='add',
                       attn_mask=attn_mask,
                       attn_mask_mode='mul')
    y.backward(dx)
    dx = x.grad.clone()
    x.grad.zero_()
    return x, dx
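# `dense_to_sparse` is not defined in this section; the sketch below is a minimal
# assumed implementation of such a helper (hypothetical, not from the source). It
# gathers the blocks selected by `layout` from a dense (Z, H, M, N) tensor into a
# compressed (Z, nnz_blocks, block, block) tensor; the exact compressed layout is
# an assumption.
import torch

def dense_to_sparse(w, layout, block):
    Z = w.size(0)
    nnz = layout.nonzero(as_tuple=False)      # (head, row-block, col-block) triples
    ret = torch.empty((Z, nnz.size(0), block, block), dtype=w.dtype, device=w.device)
    for idx, (h, i, j) in enumerate(nnz.tolist()):
        ret[:, idx, :, :] = w[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block]
    return ret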
def get_ops(self, L):
    if L not in self.__class__.ops:
        sparsity = self.sparsity
        layout = self.__class__._make_layout(self.num_attention_heads_per_partition,
                                             L // sparsity.block,
                                             sparsity.mode,
                                             sparsity.stride // sparsity.block,
                                             sparsity.unidirectional,
                                             sparsity.numverts,
                                             sparsity.vertsize)
        sparse_dot_sdd_nt = torch_blocksparse.MatMul(layout, sparsity.block, 'sdd',
                                                     trans_a=False, trans_b=True)
        sparse_dot_dsd_nn = torch_blocksparse.MatMul(layout, sparsity.block, 'dsd',
                                                     trans_a=False, trans_b=False)
        sparse_softmax = torch_blocksparse.Softmax(layout, sparsity.block)
        self.__class__.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax)
    return self.__class__.ops[L]
def get_ops(self, H, L):
    # lazily build and cache the block-sparse kernels for H heads and sequence length L
    if L not in DeepSpeedSparseSelfAttention.ops:
        spConfig = self.sparsity_config
        num_blocks = L // spConfig.block
        if num_blocks != L / spConfig.block:
            raise ValueError(
                f'Sequence length {L} must be divisible by block size {spConfig.block}')
        block_stride = spConfig.stride // spConfig.block
        if block_stride != spConfig.stride / spConfig.block:
            raise ValueError(
                f'Stride {spConfig.stride} must be divisible by block size {spConfig.block}')
        layout = DeepSpeedSparseSelfAttention._make_layout(H,
                                                           num_blocks,
                                                           spConfig.mode,
                                                           block_stride,
                                                           spConfig.attention,
                                                           spConfig.numverts,
                                                           spConfig.vertsize)
        sparse_dot_sdd_nt = torch_blocksparse.MatMul(layout, spConfig.block, 'sdd',
                                                     trans_a=False, trans_b=True)
        sparse_dot_dsd_nn = torch_blocksparse.MatMul(layout, spConfig.block, 'dsd',
                                                     trans_a=False, trans_b=False)
        sparse_softmax = torch_blocksparse.Softmax(layout, spConfig.block)
        DeepSpeedSparseSelfAttention.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn,
                                               sparse_softmax)
    return DeepSpeedSparseSelfAttention.ops[L]
def get_ops(self, L):
    # lazily build and cache the block-sparse kernels for sequence length L
    if L not in MultiheadAttention.ops:
        sparsity = self.sparsity
        layout = MultiheadAttention._make_layout(self.num_heads,
                                                 L // sparsity.block,
                                                 sparsity.mode,
                                                 sparsity.stride // sparsity.block,
                                                 sparsity.unidirectional,
                                                 sparsity.numverts,
                                                 sparsity.vertsize)
        sparse_dot_sdd_nt = torch_blocksparse.MatMul(layout, sparsity.block, 'sdd',
                                                     trans_a=False, trans_b=True)
        sparse_dot_dsd_nn = torch_blocksparse.MatMul(layout, sparsity.block, 'dsd',
                                                     trans_a=False, trans_b=False)
        sparse_softmax = torch_blocksparse.Softmax(layout, sparsity.block)
        MultiheadAttention.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax)
    return MultiheadAttention.ops[L]
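# The three cached operators above are typically combined into one block-sparse
# attention step: 'sdd' for the query-key product, sparse softmax, then 'dsd' for
# the probability-value product. The sketch below is an assumed, self-contained
# illustration; the shapes, the random layout (a stand-in for _make_layout), and
# the 1/sqrt(D) scale are assumptions, not taken from the source.
import math
import torch
import torch_blocksparse

Z, H, L, D, block = 2, 4, 256, 64, 16                        # assumed batch, heads, seq length, head dim
q = torch.rand((Z, H, L, D), dtype=torch.float32).cuda()
k = torch.rand((Z, H, L, D), dtype=torch.float32).cuda()
v = torch.rand((Z, H, L, D), dtype=torch.float32).cuda()
layout = torch.randint(0, 2, (H, L // block, L // block))    # stand-in for _make_layout(...)

sparse_dot_sdd_nt = torch_blocksparse.MatMul(layout, block, 'sdd', trans_a=False, trans_b=True)
sparse_dot_dsd_nn = torch_blocksparse.MatMul(layout, block, 'dsd', trans_a=False, trans_b=False)
sparse_softmax = torch_blocksparse.Softmax(layout, block)

scores = sparse_dot_sdd_nt(q, k)                             # block-sparse scores over the L x L layout
probs = sparse_softmax(scores, scale=1.0 / math.sqrt(D))     # sparse softmax of the kept blocks
context = sparse_dot_dsd_nn(probs, v)                        # dense (Z, H, L, D) output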
def run_bench_softmax(Z, H, M, N, scale, rho, block, dtype, layout=None, repeat=10):
    layout, x, dx, _, attn_mask, kp_mask = init_inputs(Z, H, M, N, scale, rho, block, dtype,
                                                       dense_x=False, layout=layout)
    x = x.clone()
    dx = dx.clone()
    # forward function
    sparse_softmax = torch_blocksparse.Softmax(layout, block, bench=False)
    y = sparse_softmax(x, scale, None, None, 'add', 'mul')
    # backward function
    backward = y.grad_fn.apply
    backward(dx)
    x = x.clone()
    # benchmark
    time_y = bench(lambda: sparse_softmax(x, scale, None, None, 'add', 'mul'), repeat)
    time_dx = bench(lambda: backward(dx), repeat)
    gb_y = (2 * nbytes(x) + nbytes(attn_mask) + nbytes(kp_mask)) * 1e-9
    gb_dx = 3 * nbytes(x) * 1e-9
    return time_y, time_dx, gb_y, gb_dx
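# `bench` and `nbytes` are not defined in this section; the helpers below are an
# assumed minimal version of them (hypothetical, not from the source): `nbytes`
# returns a tensor's size in bytes, and `bench` reports the average CUDA time of a
# callable in seconds using CUDA events.
import torch

def nbytes(x):
    # size of a tensor in bytes (0 for an absent mask)
    return 0 if x is None else x.numel() * x.element_size()

def bench(fn, repeat):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()                                  # warm-up
    torch.cuda.synchronize()
    start.record()
    for _ in range(repeat):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e-3 / repeat   # elapsed_time is in milliseconds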
import torch
import torch_blocksparse

# Z: non-sparse batch dimension
# H: sparse batch dimension
# M: row dimension
# N: column dimension
Z, H, M, N, K = 4, 2, 256, 512, 384
# trans_a=True below means `a` is consumed as trans(a), so it is laid out as (K, M)
a = torch.rand((Z, H, K, M), dtype=torch.float32).cuda()
b = torch.rand((Z, H, K, N), dtype=torch.float32).cuda()
# create sparsity layout
block = 16
layout = torch.randint(0, 2, (H, M // block, N // block))
# create object for Sparse = trans(Dense) x Dense (sdd)
# some overhead there as it pre-computes look-up tables
# internally needed by GPU kernels
dot = torch_blocksparse.MatMul(layout, block, 'sdd', trans_a=True, trans_b=False)
c = dot(a, b)
# create object for Sparse = softmax(Sparse)
softmax = torch_blocksparse.Softmax(layout, block)
d = softmax(c)
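# As a follow-up sketch (an assumption built on the example above, not part of the
# original): the block-sparse result `d` can be multiplied back against a dense
# operand with the 'dsd' mode, mirroring the probabilities-times-values step of
# sparse attention. `P` and `e` are hypothetical.
P = 64
e = torch.rand((Z, H, N, P), dtype=torch.float32).cuda()
# create object for Dense = Sparse x Dense (dsd); the layout describes the sparse operand
dot_dsd = torch_blocksparse.MatMul(layout, block, 'dsd', trans_a=False, trans_b=False)
f = dot_dsd(d, e)   # dense result of shape (Z, H, M, P)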