Example #1
 def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                 num_locks, width, bench, time):
     # shapes / dtypes
     AS0 = spdims[0]
     AS1 = block * spdims[2 if trans_a else 1]
     AS2 = block * spdims[1 if trans_a else 2]
     BS0 = b.size(0)
     BS1 = b.size(1)
     BS2 = b.size(3 if trans_b else 2)
     BS3 = b.size(2 if trans_b else 3)
     dtype = a.dtype
     # kernel
     key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
     if key not in _sparse_matmul.dsd_cache:
         defines = {
             'TM': block,
             'TN': 128,
             'TK': 8,
             'TYPE': dtype,
             'STRIDE_AM': 1 if trans_a else block,
             'STRIDE_AK': block if trans_a else 1,
             'STRIDE_BN': 'ldb' if trans_b else '1',
             'STRIDE_BK': '1' if trans_b else 'ldb',
             'STRIDE_CM': '1' if trans_c else 'ldc',
             'STRIDE_CN': 'ldc' if trans_c else '1',
             'NAME': 'dsd_kernel',
             'DSD': True
         }
         _sparse_matmul.dsd_cache[key] = triton.kernel(src,
                                                       defines=defines,
                                                       num_warps=[4])
     kernel = _sparse_matmul.dsd_cache[key]
     # output
     CS0 = BS0
     CS1 = BS1
     CS2 = BS3 if trans_c else AS1
     CS3 = AS1 if trans_c else BS3
     locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks)
     c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
     time[0] = kernel(
         a,
         b,
         c,
         block,
         b.stride(2),
         c.stride(2),
         a.stride(0),
         b.stride(0),
         c.stride(0),
         a.stride(1),
         b.stride(1),
         c.stride(1),
         BS3,
         0,
         0,
         lut,
         locks,
         num_locks,
         grid=lambda opt: [width, triton.cdiv(BS3, opt.d('TN')), BS0],
         bench=bench)
     return c
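
For intuition, here is a small worked sketch of how `_dsd_matmul` derives the output shape and launch grid from `spdims`, `block`, and the dense operand `b`. All sizes below are hypothetical and chosen only for illustration; `width` is simply the value passed in alongside the LUT.

    # hypothetical inputs: spdims ~ (layouts, block-rows, block-cols) of the sparse mask
    block, spdims = 16, (4, 8, 8)
    trans_a = trans_b = trans_c = False
    BS0, BS1, BS2, BS3 = 2, 4, 128, 64            # dense b: (batch, head, K, N)

    AS1 = block * spdims[2 if trans_a else 1]     # 16 * 8 = 128 (sparse rows, in elements)
    CS = (BS0, BS1, BS3 if trans_c else AS1, AS1 if trans_c else BS3)   # (2, 4, 128, 64)

    cdiv = lambda x, y: (x + y - 1) // y          # what triton.cdiv computes
    TN, width = 128, 32                           # TN from the defines above; width is an input
    grid = [width, cdiv(BS3, TN), BS0]            # [32, 1, 2] program instances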
Example #2
 def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays):
     # parse symbols
     expr_a, expr_bc = einsum.split(",")
     expr_b, expr_c  = expr_bc.split("->")
     subscripted = []
     sym_a = _einsum.parse_expr(expr_a, subscripted)
     sym_b = _einsum.parse_expr(expr_b, subscripted)
     sym_c = _einsum.parse_expr(expr_c, subscripted)
     # parse axes
     axes_b, axes_m, axes_k = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
     _, axes_n, _           = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
     axes = axes_b + axes_m + axes_n + axes_k
     # check dimensions
     dims_a  = dict(zip(sym_a, shape_a))
     dims_b  = dict(zip(sym_b, shape_b))
     dims_c  = dict(zip(sym_c, shape_c))
     for ax in [axes_b, axes_k]:
         for d in ax:
             dim_a = dims_a[d] if d in sym_a else None
             dim_b = dims_b[d] if d in sym_b else None
             if dim_a and dim_b and dim_a != dim_b:
                 raise ValueError(f'incompatible dimension {d}'
                                 f' (a: {dim_a}; b: {dim_b})')
     dims = dict()
     dims.update(dims_a)
     dims.update(dims_b)
     dims.update(dims_c)
     # look-up tables
     TK = 16 if dtype == triton.fw.torch.float16 else 8
     arrays = [(x, arrays[x]) for x in subscripted]
     delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
     delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
     # hash for recompilation
     stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
     stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
     stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
     name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
         f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'
     # recompile if necessary
     cache = _einsum.instance.kernel_cache
     if name not in cache:
         cachesize = len(cache)
         cache[name] = _einsum.make_kernel(f'__einsum{cachesize}', 
                                                 sym_a, sym_b, sym_c, 
                                                 axes_m, axes_n, axes_k, axes_b, 
                                                 stride_a_multiple, stride_b_multiple, stride_c_multiple,
                                                 lut_mode_a, lut_mode_b,
                                                 delta_a, delta_b,
                                                 subscripted)
     self.kernel = cache[name]
     # Initialize locks
     if _einsum.instance.locks is None:
         _einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda()
     # Kernel arguments
     dim_m = [dims[d] for d in axes_m]
     dim_n = [dims[d] for d in axes_n]
     dim_k = [dims[d] for d in axes_k]
     dim_b = [dims[d] for d in axes_b]
     M = reduce(mul, dim_m, 1)
     N = reduce(mul, dim_n, 1)
     K = reduce(mul, dim_k, 1)
     B = reduce(mul, dim_b, 1)
     stride_a = list(stride_a[:-1])
     stride_b = list(stride_b[:-1])
     stride_c = list(stride_c[:-1])
     arrays = [torch.from_numpy(x).cuda() for _, x in arrays]
     alpha = 1.
     div_m = 1
     self.args = [None, None, None,
                  _einsum.instance.locks, 
                  alpha, M, N, K, div_m] +\
                  dim_m + dim_n +  dim_k + dim_b +\
                  stride_a + stride_b + stride_c
     if lut_mode_a != _einsum.LUT_MODE.CONSTANT:
         delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda()
         self.args += [delta_a]
     if lut_mode_b != _einsum.LUT_MODE.CONSTANT:
         delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
         self.args += [delta_b]
     self.args += arrays
     self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) * 
                             triton.cdiv(N, opt.d('TN')),
                             triton.cdiv(B, opt.d('TB')),
                             opt.d('TZ')]
     # position of dynamic arguments
     self.pos_a = 0
     self.pos_b = 1
     self.pos_c = 2
     # pre-processor macros
     TM = [16] + [x for x in [32, 64, 128] if x <= M]
     TN = [16] + [x for x in [32, 64, 128] if x <= N]
     TB = [x for x in [1, 2, 4] if x <= B]
     MAX_GZ = K // 2048
     MIN_GM = M // max(TM)
     MIN_GN = N // max(TN)
     MIN_GB = B // max(TB)
     TZ = [x for x in [1, 2, 4, 8, 16, 32] \
             if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
     TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
     self.macros = {  'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
     # information on compute
     self.dtype = dtype
     self.flops = 2 * B * M * N * K
     self.sym_a = sym_a
     self.sym_b = sym_b
     self.sym_c = sym_c
     # save equivalent mat-mul dimensions
     self.matmul_B = B
     self.matmul_M = M
     self.matmul_N = N
     self.matmul_K = K
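
To make the dimension bookkeeping concrete, here is a minimal sketch of how the equivalent mat-mul sizes B, M, N, K are reduced from the parsed symbols. The einsum, the shapes, and the axis split (which `parse_axes` would presumably return for this simple case) are all made up for illustration.

    from functools import reduce
    from operator import mul

    einsum = 'bmk,bkn->bmn'                        # hypothetical expression
    shape_a, shape_b = (8, 64, 32), (8, 32, 128)   # hypothetical operand shapes

    expr_a, expr_bc = einsum.split(",")
    expr_b, expr_c = expr_bc.split("->")
    dims = dict(zip(expr_a, shape_a))              # {'b': 8, 'm': 64, 'k': 32}
    dims.update(zip(expr_b, shape_b))              # adds 'n': 128 (b and k unchanged)

    # assumed axis split for this expression: batch, rows of a, cols of b, reduction
    axes_b, axes_m, axes_n, axes_k = ['b'], ['m'], ['n'], ['k']
    B = reduce(mul, [dims[d] for d in axes_b], 1)  # 8
    M = reduce(mul, [dims[d] for d in axes_m], 1)  # 64
    N = reduce(mul, [dims[d] for d in axes_n], 1)  # 128
    K = reduce(mul, [dims[d] for d in axes_k], 1)  # 32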
Example #3
 def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                 num_locks, width, packs, bench, time):
     # shapes / dtypes
     AS0 = a.size(0)
     AS1 = a.size(1)
     AS2 = a.size(3 if trans_a else 2)
     AS3 = a.size(2 if trans_a else 3)
     BS0 = spdims[0]
     BS1 = block * spdims[2 if trans_b else 1]
     BS2 = block * spdims[1 if trans_b else 2]
     dtype = a.dtype
     # kernel
     key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
     if key not in _sparse_matmul.dds_cache:
         TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
         TK = [8] if dtype == torch.float32 else [16]
         defines = {
             'TM': TM,
             'TN': block,
             'TK': TK,
             'BLOCK': block,
             'TYPE': dtype,
             'STRIDE_AM': 1 if trans_a else 'lda',
             'STRIDE_AK': 'lda' if trans_a else 1,
             'STRIDE_BN': block if trans_b else 1,
             'STRIDE_BK': 1 if trans_b else block,
             'STRIDE_CM': '1' if trans_c else 'ldc',
             'STRIDE_CN': 'ldc' if trans_c else '1',
             'NAME': 'dds_kernel',
             'DDS': True
         }
         _sparse_matmul.dds_cache[key] = triton.kernel(src,
                                                       defines=defines,
                                                       num_warps=[4])
     kernel = _sparse_matmul.dds_cache[key]
     # output
     CS0 = AS0
     CS1 = AS1
     CS2 = BS2 if trans_c else AS2
     CS3 = AS2 if trans_c else BS2
     locks = _sparse_matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks,
                                      a.device)
     c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
     time[0] = kernel(
         a,
         b,
         c,
         a.stride(2),
         block,
         c.stride(2),
         a.stride(0),
         b.stride(0),
         c.stride(0),
         a.stride(1),
         b.stride(1),
         c.stride(1),
         AS2,
         BS2,
         0,
         0,
         lut,
         locks,
         num_locks,
         grid=lambda opt: [width, triton.cdiv(AS2, opt.d('TM')), AS0],
         bench=bench)
     return c
Example #4
 def do_work(x, in_order, out_order):
     x_inner_mul = _permute.multiple_of(x.shape['NCHW'.index(in_order[-1])])
     y_inner_mul = _permute.multiple_of(x.shape['NCHW'.index(
         out_order[-1])])
     key = (x.dtype, in_order, out_order, x_inner_mul, y_inner_mul)
     if key not in _permute.kernels:
         TN = [32] if in_order[-1] == 'N' or out_order[-1] == 'N' else 1
         TC = [32] if in_order[-1] == 'C' or out_order[-1] == 'C' else 1
         THW = [32] if in_order[-1] == 'W' or out_order[-1] == 'W' else 1
         defines = {
             'NAME':
             f'permute_{in_order}_{out_order}_{x_inner_mul}_{y_inner_mul}',
             'TYPE': x.dtype,
             # stride multiple for X
             'M_STRIDE_XN': 1 if in_order[-1] == 'N' else x_inner_mul,
             'M_STRIDE_XC': 1 if in_order[-1] == 'C' else x_inner_mul,
             'M_STRIDE_XHW': 1 if in_order[-1] == 'W' else x_inner_mul,
             # stride multiple for Y
             'M_STRIDE_YN': 1 if out_order[-1] == 'N' else y_inner_mul,
             'M_STRIDE_YC': 1 if out_order[-1] == 'C' else y_inner_mul,
             'M_STRIDE_YHW': 1 if out_order[-1] == 'W' else y_inner_mul,
             # strides for X
             'STRIDE_XN': 1 if in_order[-1] == 'N' else 'stride_xn',
             'STRIDE_XC': 1 if in_order[-1] == 'C' else 'stride_xc',
             'STRIDE_XHW': 1 if in_order[-1] == 'W' else 'stride_xhw',
             # strides for Y
             'STRIDE_YN': 1 if out_order[-1] == 'N' else 'stride_yn',
             'STRIDE_YC': 1 if out_order[-1] == 'C' else 'stride_yc',
             'STRIDE_YHW': 1 if out_order[-1] == 'W' else 'stride_yhw',
             # tile parameters
             'TN': TN,
             'TC': TC,
             'THW': THW
         }
         _permute.kernels[key] = triton.kernel(src,
                                               defines=defines,
                                               num_warps=[4])
     kernel = _permute.kernels[key]
     N, C, H, W = x.shape
     y = torch.empty_strided(x.shape,
                             _permute.strides(N, C, H, W, out_order),
                             device=x.device,
                             dtype=x.dtype)
     stride_xn, stride_xc, _, stride_xhw = x.stride()
     stride_yn, stride_yc, _, stride_yhw = y.stride()
     grid = lambda opt: (triton.cdiv(N, opt.d('TN')),
                         triton.cdiv(C, opt.d('TC')),
                         triton.cdiv(H * W, opt.d('THW')))
     kernel(x,
            y,
            N,
            C,
            H * W,
            stride_xn,
            stride_xc,
            stride_xhw,
            stride_yn,
            stride_yc,
            stride_yhw,
            grid=grid)
     return y
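
`_permute.strides` is not shown here, so the helper below is only an assumption about what it plausibly computes: element strides that make `y` contiguous in `out_order` while keeping the logical NCHW shape that `torch.empty_strided` receives.

    def permuted_strides(N, C, H, W, out_order):
        # innermost axis of out_order gets stride 1, then grow outwards
        size = {'N': N, 'C': C, 'H': H, 'W': W}
        stride, strides = 1, {}
        for axis in reversed(out_order):
            strides[axis] = stride
            stride *= size[axis]
        # report the strides in logical NCHW order
        return tuple(strides[axis] for axis in 'NCHW')

    print(permuted_strides(2, 3, 4, 5, 'NHWC'))    # (60, 1, 15, 3): channels-last layout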
Example #5
 def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                 num_locks, width, packs):
     # shapes / dtypes
     AS0 = a.size(0)
     AS1 = a.size(1)
     AS2 = a.size(3 if trans_a else 2)
     AS3 = a.size(2 if trans_a else 3)
     BS0 = spdims[0]
     BS1 = block * spdims[2 if trans_b else 1]
     BS2 = block * spdims[1 if trans_b else 2]
     dtype = a.dtype
     # kernel
     key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
     if key not in _matmul.dds_cache:
         defines = {
             'TM': 128,
             'TN': block,
             'TK': 16,
             'BLOCK': block,
             'TYPE': dtype,
             'STRIDE_AM': 1 if trans_a else 'lda',
             'STRIDE_AK': 'lda' if trans_a else 1,
             'STRIDE_BN': block if trans_b else 1,
             'STRIDE_BK': 1 if trans_b else block,
             'STRIDE_CM': '1' if trans_c else 'ldc',
             'STRIDE_CN': 'ldc' if trans_c else '1',
             'NAME': 'dds_kernel',
             'DDS': True
         }
         _matmul.dds_cache[key] = triton.kernel(src,
                                                device=a.device,
                                                defines=defines)
     kernel = _matmul.dds_cache[key]
     # output
     CS0 = AS0
     CS1 = AS1
     CS2 = BS2 if trans_c else AS2
     CS3 = AS2 if trans_c else BS2
     locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
     c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
     kernel(a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(),
            a.stride(2),
            block,
            c.stride(2),
            a.stride(0),
            b.stride(0),
            c.stride(0),
            a.stride(1),
            b.stride(1),
            c.stride(1),
            AS2,
            BS2,
            0,
            0,
            lut.data_ptr(),
            locks.data_ptr(),
            num_locks,
            grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
     return c
Example #6
    def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                    num_locks, width, packs, bench, time):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')

        # shapes / dtypes
        AS0 = spdims[0]
        AS1 = block * spdims[2 if trans_a else 1]
        AS2 = block * spdims[1 if trans_a else 2]
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        # kernel

        meta = {
            'TM': block,
            'TN': 128,
            'TK': 16,
            'BLOCK': block,
            'TZ': 1,
            'SDD': False,
            'DSD': True,
            'DDS': False
        }
        # output
        CS0 = BS0
        CS1 = BS1
        CS2 = BS3 if trans_c else AS1
        CS3 = AS1 if trans_c else BS3
        locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks,
                                         a.device)
        c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
        grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
        _kernel[grid](a,
                      b,
                      c,
                      a.stride(0),
                      a.stride(1),
                      a.stride(3 if trans_a else 2),
                      a.stride(2 if trans_a else 3),
                      b.stride(0),
                      b.stride(1),
                      b.stride(3 if trans_b else 2),
                      b.stride(2 if trans_b else 3),
                      c.stride(0),
                      c.stride(1),
                      c.stride(2),
                      c.stride(3),
                      BS3,
                      AS1,
                      0,
                      0,
                      lut,
                      locks,
                      num_locks,
                      num_warps=4,
                      **meta)
        return c
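
This variant uses the `kernel[grid](..., **meta)` launch style. As a reminder of that idiom, here is a self-contained vector-add sketch written against the same 1.x-era API; the `**meta` signature and launch keywords are assumed to match the Triton version this snippet targets.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, **meta):
        BLOCK = meta['BLOCK']
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)

    x = torch.randn(4096, device='cuda')
    y = torch.randn(4096, device='cuda')
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']),)
    add_kernel[grid](x, y, out, x.numel(), BLOCK=1024, num_warps=4)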
Example #7
    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode,
                attn_mask_mode, spdims, block, lut, num_blocks, maxlut, bench,
                time):
        apply_scale = scale != 1.0

        # handle None rpe
        if rpe is None:
            apply_rpe = False
            stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
            rpe = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_rpe = True
            stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(
                1), rpe.stride(2)

        # handle None key_padding_mask
        if key_padding_mask is None:
            apply_kp_mask = False
            stride_zkpm = 0
            key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_kp_mask = True
            stride_zkpm = key_padding_mask.stride(0)

        # handle None attention_mask
        if attn_mask is None:
            apply_attn_mask = False
            stride_zattnm = 0
            attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
        else:
            apply_attn_mask = True
            stride_zattnm = attn_mask.stride(0)

        # get kernel
        kernel = _sparse_softmax.make_kernel(fwd_kernels, softmax_fwd,
                                             maxlut * block, x.dtype, block,
                                             apply_scale, apply_rpe,
                                             apply_kp_mask, apply_attn_mask,
                                             kp_mask_mode, attn_mask_mode)
        M = x.shape[0]
        grid = lambda opt: [
            triton.cdiv(spdims[0] * spdims[1] * block, opt.d('TM')), M
        ]

        # run kernel
        time[0] = kernel(x, scale, lut, rpe, key_padding_mask, attn_mask,
                         num_blocks, maxlut,
                         x.stride(0),
                         stride_zrpe, stride_hrpe, stride_srpe,
                         stride_zkpm, stride_zattnm,
                         grid=grid, bench=bench)
        # save to context
        ctx.mark_dirty(x)
        ctx.save_for_backward(x, lut)
        ctx.spdims = spdims
        ctx.block = block
        ctx.maxlut = maxlut
        ctx.scale = scale
        ctx.apply_scale = apply_scale
        ctx.apply_rpe = apply_rpe
        ctx.apply_kp_mask = apply_kp_mask
        ctx.apply_attn_mask = apply_attn_mask
        ctx.kp_mask_mode = kp_mask_mode
        ctx.attn_mask_mode = attn_mask_mode
        return x
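
The forward mutates `x` in place and returns it, which is why it calls `ctx.mark_dirty(x)` before saving tensors. A minimal standalone sketch of that `torch.autograd.Function` pattern (a hypothetical in-place scale, unrelated to the softmax kernel itself):

    import torch

    class _inplace_scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale):
            x.mul_(scale)             # mutate the input buffer, as the softmax kernel does
            ctx.mark_dirty(x)         # tell autograd the input was modified in place
            ctx.save_for_backward(x)  # saved for parity with the snippet; unused below
            ctx.scale = scale
            return x

        @staticmethod
        def backward(ctx, dy):
            return dy * ctx.scale, None

    a = torch.randn(4, requires_grad=True)
    y = _inplace_scale.apply(a.clone(), 2.0)   # clone: a leaf itself must not be mutated
    y.sum().backward()
    print(a.grad)                              # tensor([2., 2., 2., 2.])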
Example #8
 def _call(a, b, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w,
           upsample_d, upsample_h, upsample_w, a_layout, b_layout,
           c_layout):
     # input shapes
     shape_a = list(triton.shape(a))
     shape_b = list(triton.shape(b))
     dim = len(shape_a) - 2
     # indices
     an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw']
     bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs']
     cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw']
     # extract shapes
     if dim == 2:
         shape_a.insert(ad, 1)
         shape_b.insert(bd, 1)
     # output shape
     shape_c = [0] * 5
     shape_c[cn] = shape_a[an]
     shape_c[ck] = shape_b[bk]
     shape_c[cd] = (shape_a[ad] * upsample_d - shape_b[bd] + 1 + 2 * pad_d +
                    stride_d - 1) // stride_d
     shape_c[ch] = (shape_a[ah] * upsample_h - shape_b[bh] + 1 + 2 * pad_h +
                    stride_h - 1) // stride_h
     shape_c[cw] = (shape_a[aw] * upsample_w - shape_b[bw] + 1 + 2 * pad_w +
                    stride_w - 1) // stride_w
     # strides
     stride_a = _conv._extract_strides(shape_a)
     stride_b = _conv._extract_strides(shape_b)
     stride_c = _conv._extract_strides(shape_c)
     # tiling parameters
     TM = [32]
     TN = [32]
     TK = 8
     # pointer deltas for a
     delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w, bc, bd,
                              bh, bw, ac, ad, ah, aw, stride_a, shape_b, TK)
     delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
     # delta increments for a
     inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32)
     inc_a = ((inc_a + TK) % inc_a.size) - inc_a
     inc_a = triton.fw.torch.from_numpy(inc_a).cuda()
     # allocate output
     if dim == 2:
         shape_c.pop(cd)
     c = triton.empty(shape_c, dtype=a.dtype)
     if dim == 2:
         shape_c.insert(cd, 1)
     # execute kernel
     trans_b = False
     is_wgrad = False
     is_blut = False
     macros = {
         'UPAR': 'stride_h' if is_wgrad else '1',
         'UPAS': '******' if is_wgrad else '1',
         'UPAH': '' if is_wgrad else 'stride_h',
         'UPAW': '' if is_wgrad else 'stride_w',
         'LUT_SIZE': delta_a.shape[-1],
         'TM': TM,
         'TN': TN,
         'TK': TK,
         'A_TYPE': 'float',
         'B_TYPE': 'float'
     }
     MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw]
     MATMUL_N = shape_c[ck]
     MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw]
     _conv.kernel(
         a,
         b,
         c,
         # matrix multiplication shapes
         MATMUL_M,
         MATMUL_N,
         MATMUL_K,
         # shapes for a
         shape_a[ah],
         shape_a[aw],
         # shapes for b
         shape_b[bh],
         shape_b[bw],
         # shapes for c
         shape_c[ch],
         shape_c[cw],
         shape_c[cn],
         # strides for a
         stride_a[an],
         stride_a[ac],
         stride_a[ad + 0],
         stride_a[ad + 1],
         stride_a[ad + 2],
         # strides for b
         stride_b[bc],
         stride_b[bd + 0],
         stride_b[bd + 1],
         stride_b[bd + 2],
         stride_b[bk],
         # strides for c
         stride_c[cn],
         stride_c[ck],
         stride_c[cd],
         stride_c[cd + 1],
         stride_c[cd + 2],
         # padding
         pad_h,
         pad_w,
         # striding
         stride_h,
         stride_w,
         # upsampling
         upsample_h,
         upsample_w,
         0,
         0,
         0,
         0,
         0,
         0,
         # look-up table
         delta_a,
         inc_a,
         lambda opt: [
             triton.cdiv(MATMUL_M, opt.d('TM')),
             triton.cdiv(MATMUL_N, opt.d('TN'))
         ],
         **macros)
     return c
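
The output spatial extents above are ceiling divisions over the valid convolution range, i.e. roughly `ceil((in * upsample - kernel + 1 + 2 * pad) / stride)`. A quick arithmetic check with hypothetical sizes:

    H, R, pad_h, stride_h, upsample_h = 32, 3, 1, 2, 1   # hypothetical sizes
    CH = (H * upsample_h - R + 1 + 2 * pad_h + stride_h - 1) // stride_h
    print(CH)   # 16, matching the usual (H + 2*pad - R) // stride + 1 when upsample == 1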
Example #9
 def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                 num_locks, width, packs):
     # shapes / dtypes
     AS0 = spdims[0]
     AS1 = block * spdims[2 if trans_a else 1]
     AS2 = block * spdims[1 if trans_a else 2]
     BS0 = b.size(0)
     BS1 = b.size(1)
     BS2 = b.size(3 if trans_b else 2)
     BS3 = b.size(2 if trans_b else 3)
     dtype = a.dtype
     # kernel
     key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
     if key not in _matmul.dsd_cache:
         TN = [64, 128]
         TK = [8] if dtype == torch.float32 else [16]
         defines = {
             'TM': block,
             'TN': TN,
             'TK': TK,
             'BLOCK': block,
             'TYPE': dtype,
             'STRIDE_AM': 1 if trans_a else block,
             'STRIDE_AK': block if trans_a else 1,
             'STRIDE_BN': 'ldb' if trans_b else '1',
             'STRIDE_BK': '1' if trans_b else 'ldb',
             'STRIDE_CM': '1' if trans_c else 'ldc',
             'STRIDE_CN': 'ldc' if trans_c else '1',
             'NAME': 'dsd_kernel',
             'DSD': True
         }
         _matmul.dsd_cache[key] = triton.kernel(src,
                                                device=a.device,
                                                defines=defines,
                                                num_warps=[4])
     kernel = _matmul.dsd_cache[key]
     # output
     CS0 = BS0
     CS1 = BS1
     CS2 = BS3 if trans_c else AS1
     CS3 = AS1 if trans_c else BS3
     locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
     c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
     kernel(a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(),
            block,
            b.stride(2),
            c.stride(2),
            a.stride(0),
            b.stride(0),
            c.stride(0),
            a.stride(1),
            b.stride(1),
            c.stride(1),
            BS3,
            AS1,
            0,
            0,
            lut.data_ptr(),
            locks.data_ptr(),
            num_locks,
            grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
     return c