def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                num_locks, width, bench, time):
    # shapes / dtypes
    AS0 = spdims[0]
    AS1 = block * spdims[2 if trans_a else 1]
    AS2 = block * spdims[1 if trans_a else 2]
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    # kernel
    key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _sparse_matmul.dsd_cache:
        defines = {
            'TM': block,
            'TN': 128,
            'TK': 8,
            'TYPE': dtype,
            'STRIDE_AM': 1 if trans_a else block,
            'STRIDE_AK': block if trans_a else 1,
            'STRIDE_BN': 'ldb' if trans_b else '1',
            'STRIDE_BK': '1' if trans_b else 'ldb',
            'STRIDE_CM': '1' if trans_c else 'ldc',
            'STRIDE_CN': 'ldc' if trans_c else '1',
            'NAME': 'dsd_kernel',
            'DSD': True
        }
        _sparse_matmul.dsd_cache[key] = triton.kernel(src, defines=defines, num_warps=[4])
    kernel = _sparse_matmul.dsd_cache[key]
    # output
    CS0 = BS0
    CS1 = BS1
    CS2 = BS3 if trans_c else AS1
    CS3 = AS1 if trans_c else BS3
    locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    time[0] = kernel(a, b, c,
                     block, b.stride(2), c.stride(2),
                     a.stride(0), b.stride(0), c.stride(0),
                     a.stride(1), b.stride(1), c.stride(1),
                     BS3, 0, 0, lut, locks, num_locks,
                     grid=lambda opt: [width, triton.cdiv(BS3, opt.d('TN')), BS0],
                     bench=bench)
    return c
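# Illustrative sketch (not part of the kernels above): how the DSD shape
# bookkeeping works. In DSD mode, A is block-sparse and described by
# spdims = (num_layouts, block_rows, block_cols) plus a block size, while B
# and the output C are dense 4D tensors. The helper name below is
# hypothetical and only mirrors the AS*/BS*/CS* arithmetic in _dsd_matmul.
def _dsd_output_shape(spdims, block, b_shape, trans_a=False, trans_b=False, trans_c=False):
    AS1 = block * spdims[2 if trans_a else 1]   # dense row extent of sparse A
    BS0, BS1, BS2, BS3 = b_shape
    if trans_b:
        BS2, BS3 = BS3, BS2
    # the output keeps B's leading (batch, head) dims; inner dims depend on trans_c
    CS2 = BS3 if trans_c else AS1
    CS3 = AS1 if trans_c else BS3
    return (BS0, BS1, CS2, CS3)

# e.g. an 8-layout, 4x4-block layout of 16x16 blocks times a dense (2, 8, 64, 32) tensor
assert _dsd_output_shape((8, 4, 4), 16, (2, 8, 64, 32)) == (2, 8, 64, 32)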
def __init__(self, einsum, dtype,
             stride_a, stride_b, stride_c,
             shape_a, shape_b, shape_c,
             arrays):
    # parse symbols
    expr_a, expr_bc = einsum.split(",")
    expr_b, expr_c = expr_bc.split("->")
    subscripted = []
    sym_a = _einsum.parse_expr(expr_a, subscripted)
    sym_b = _einsum.parse_expr(expr_b, subscripted)
    sym_c = _einsum.parse_expr(expr_c, subscripted)
    # parse axes
    axes_b, axes_m, axes_k = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted)
    _, axes_n, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted)
    axes = axes_b + axes_m + axes_n + axes_k
    # check dimensions
    dims_a = dict(zip(sym_a, shape_a))
    dims_b = dict(zip(sym_b, shape_b))
    dims_c = dict(zip(sym_c, shape_c))
    for axes in [axes_b, axes_k]:
        for d in axes:
            dim_a = dims_a[d] if d in sym_a else None
            dim_b = dims_b[d] if d in sym_b else None
            if dim_a and dim_b and dim_a != dim_b:
                raise ValueError(f'incompatible dimension {d}'
                                 f' (a: {dim_a}; b: {dim_b})')
    dims = dict()
    dims.update(dims_a)
    dims.update(dims_b)
    dims.update(dims_c)
    # look-up tables
    TK = 16 if dtype == triton.fw.torch.float16 else 8
    arrays = [(x, arrays[x]) for x in subscripted]
    delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
    delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
    # hash for recompilation
    stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0])
    stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0])
    stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0])
    name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\
           f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'
    # recompile if necessary
    cache = _einsum.instance.kernel_cache
    if name not in cache:
        cachesize = len(cache)
        cache[name] = _einsum.make_kernel(f'__einsum{cachesize}',
                                          sym_a, sym_b, sym_c,
                                          axes_m, axes_n, axes_k, axes_b,
                                          stride_a_multiple, stride_b_multiple, stride_c_multiple,
                                          lut_mode_a, lut_mode_b,
                                          delta_a, delta_b,
                                          subscripted)
    self.kernel = cache[name]
    # Initialize locks
    if _einsum.instance.locks is None:
        _einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda()
    # Kernel arguments
    dim_m = [dims[d] for d in axes_m]
    dim_n = [dims[d] for d in axes_n]
    dim_k = [dims[d] for d in axes_k]
    dim_b = [dims[d] for d in axes_b]
    M = reduce(mul, dim_m, 1)
    N = reduce(mul, dim_n, 1)
    K = reduce(mul, dim_k, 1)
    B = reduce(mul, dim_b, 1)
    stride_a = list(stride_a[:-1])
    stride_b = list(stride_b[:-1])
    stride_c = list(stride_c[:-1])
    arrays = [torch.from_numpy(x).cuda() for _, x in arrays]
    alpha = 1.
    div_m = 1
    self.args = [None, None, None,
                 _einsum.instance.locks,
                 alpha, M, N, K, div_m] +\
                dim_m + dim_n + dim_k + dim_b +\
                stride_a + stride_b + stride_c
    if lut_mode_a != _einsum.LUT_MODE.CONSTANT:
        delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda()
        self.args += [delta_a]
    if lut_mode_b != _einsum.LUT_MODE.CONSTANT:
        delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
        self.args += [delta_b]
    self.args += arrays
    self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) * triton.cdiv(N, opt.d('TN')),
                             triton.cdiv(B, opt.d('TB')),
                             opt.d('TZ')]
    # position of dynamic arguments
    self.pos_a = 0
    self.pos_b = 1
    self.pos_c = 2
    # pre-processor macros
    TM = [16] + [x for x in [32, 64, 128] if x <= M]
    TN = [16] + [x for x in [32, 64, 128] if x <= N]
    TB = [x for x in [1, 2, 4] if x <= B]
    MAX_GZ = K // 2048
    MIN_GM = M // max(TM)
    MIN_GN = N // max(TN)
    MIN_GB = B // max(TB)
    TZ = [x for x in [1, 2, 4, 8, 16, 32]
          if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
    TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
    self.macros = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
    # information on compute
    self.dtype = dtype
    self.flops = 2 * B * M * N * K
    self.sym_a = sym_a
    self.sym_b = sym_b
    self.sym_c = sym_c
    # save equivalent mat-mul dimensions
    self.matmul_B = B
    self.matmul_M = M
    self.matmul_N = N
    self.matmul_K = K
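# Illustrative sketch (hypothetical helper, not the _einsum internals): how an
# einsum expression reduces to equivalent mat-mul dimensions B, M, N, K,
# mirroring the reduce(mul, ...) bookkeeping above. Assumes single-letter
# subscripts and no broadcasting, which is narrower than what _einsum supports.
from functools import reduce
from operator import mul

def _matmul_dims(einsum, shape_a, shape_b):
    expr_a, expr_bc = einsum.split(",")
    expr_b, expr_c = expr_bc.split("->")
    dims = dict(zip(expr_a, shape_a))
    dims.update(zip(expr_b, shape_b))
    axes_b = [d for d in expr_a if d in expr_b and d in expr_c]      # batch
    axes_m = [d for d in expr_a if d not in expr_b and d in expr_c]  # rows of A
    axes_n = [d for d in expr_b if d not in expr_a and d in expr_c]  # cols of B
    axes_k = [d for d in expr_a if d in expr_b and d not in expr_c]  # reduction
    prod = lambda axes: reduce(mul, [dims[d] for d in axes], 1)
    return prod(axes_b), prod(axes_m), prod(axes_n), prod(axes_k)

# e.g. a batched mat-mul: 8 x (128 x 32) @ (32 x 256) -> B=8, M=128, N=256, K=32
assert _matmul_dims('bmk,bkn->bmn', (8, 128, 32), (8, 32, 256)) == (8, 128, 256, 32)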
def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                num_locks, width, packs, bench, time):
    # shapes / dtypes
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    AS3 = a.size(2 if trans_a else 3)
    BS0 = spdims[0]
    BS1 = block * spdims[2 if trans_b else 1]
    BS2 = block * spdims[1 if trans_b else 2]
    dtype = a.dtype
    # kernel
    key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _sparse_matmul.dds_cache:
        TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
        TK = [8] if dtype == torch.float32 else [16]
        defines = {
            'TM': TM,
            'TN': block,
            'TK': TK,
            'BLOCK': block,
            'TYPE': dtype,
            'STRIDE_AM': 1 if trans_a else 'lda',
            'STRIDE_AK': 'lda' if trans_a else 1,
            'STRIDE_BN': block if trans_b else 1,
            'STRIDE_BK': 1 if trans_b else block,
            'STRIDE_CM': '1' if trans_c else 'ldc',
            'STRIDE_CN': 'ldc' if trans_c else '1',
            'NAME': 'dds_kernel',
            'DDS': True
        }
        _sparse_matmul.dds_cache[key] = triton.kernel(src, defines=defines, num_warps=[4])
    kernel = _sparse_matmul.dds_cache[key]
    # output
    CS0 = AS0
    CS1 = AS1
    CS2 = BS2 if trans_c else AS2
    CS3 = AS2 if trans_c else BS2
    locks = _sparse_matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    time[0] = kernel(a, b, c,
                     a.stride(2), block, c.stride(2),
                     a.stride(0), b.stride(0), c.stride(0),
                     a.stride(1), b.stride(1), c.stride(1),
                     AS2, BS2, 0, 0, lut, locks, num_locks,
                     grid=lambda opt: [width, triton.cdiv(AS2, opt.d('TM')), AS0],
                     bench=bench)
    return c
def do_work(x, in_order, out_order):
    x_inner_mul = _permute.multiple_of(x.shape['NCHW'.index(in_order[-1])])
    y_inner_mul = _permute.multiple_of(x.shape['NCHW'.index(out_order[-1])])
    key = (x.dtype, in_order, out_order, x_inner_mul, y_inner_mul)
    if key not in _permute.kernels:
        TN = [32] if in_order[-1] == 'N' or out_order[-1] == 'N' else 1
        TC = [32] if in_order[-1] == 'C' or out_order[-1] == 'C' else 1
        THW = [32] if in_order[-1] == 'W' or out_order[-1] == 'W' else 1
        defines = {
            'NAME': f'permute_{in_order}_{out_order}_{x_inner_mul}_{y_inner_mul}',
            'TYPE': x.dtype,
            # stride multiple for X: the innermost dimension has stride 1,
            # all others are multiples of the innermost extent
            'M_STRIDE_XN': 1 if in_order[-1] == 'N' else x_inner_mul,
            'M_STRIDE_XC': 1 if in_order[-1] == 'C' else x_inner_mul,
            'M_STRIDE_XHW': 1 if in_order[-1] == 'W' else x_inner_mul,
            # stride multiple for Y
            'M_STRIDE_YN': 1 if out_order[-1] == 'N' else y_inner_mul,
            'M_STRIDE_YC': 1 if out_order[-1] == 'C' else y_inner_mul,
            'M_STRIDE_YHW': 1 if out_order[-1] == 'W' else y_inner_mul,
            # strides for X
            'STRIDE_XN': 1 if in_order[-1] == 'N' else 'stride_xn',
            'STRIDE_XC': 1 if in_order[-1] == 'C' else 'stride_xc',
            'STRIDE_XHW': 1 if in_order[-1] == 'W' else 'stride_xhw',
            # strides for Y
            'STRIDE_YN': 1 if out_order[-1] == 'N' else 'stride_yn',
            'STRIDE_YC': 1 if out_order[-1] == 'C' else 'stride_yc',
            'STRIDE_YHW': 1 if out_order[-1] == 'W' else 'stride_yhw',
            # tile parameters
            'TN': TN,
            'TC': TC,
            'THW': THW
        }
        _permute.kernels[key] = triton.kernel(src, defines=defines, num_warps=[4])
    kernel = _permute.kernels[key]
    N, C, H, W = x.shape
    y = torch.empty_strided(x.shape, _permute.strides(N, C, H, W, out_order),
                            device=x.device, dtype=x.dtype)
    stride_xn, stride_xc, _, stride_xhw = x.stride()
    stride_yn, stride_yc, _, stride_yhw = y.stride()
    grid = lambda opt: (triton.cdiv(N, opt.d('TN')),
                        triton.cdiv(C, opt.d('TC')),
                        triton.cdiv(H * W, opt.d('THW')))
    kernel(x, y, N, C, H * W,
           stride_xn, stride_xc, stride_xhw,
           stride_yn, stride_yc, stride_yhw,
           grid=grid)
    return y
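# Illustrative sketch (hypothetical helper, mirroring what _permute.strides is
# used for above): compute NCHW-ordered strides for a tensor whose memory
# layout follows `order`, by walking the layout from innermost to outermost
# and accumulating extents.
def _strides_for_order(N, C, H, W, order):
    size = dict(N=N, C=C, H=H, W=W)
    strides = {}
    acc = 1
    for d in reversed(order):   # last character of `order` is the innermost dimension
        strides[d] = acc
        acc *= size[d]
    return strides['N'], strides['C'], strides['H'], strides['W']

# e.g. channels-last layout of a (2, 3, 4, 5) tensor -> strides (60, 1, 15, 3)
assert _strides_for_order(2, 3, 4, 5, 'NHWC') == (60, 1, 15, 3)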
def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                num_locks, width, packs):
    # shapes / dtypes
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    AS3 = a.size(2 if trans_a else 3)
    BS0 = spdims[0]
    BS1 = block * spdims[2 if trans_b else 1]
    BS2 = block * spdims[1 if trans_b else 2]
    dtype = a.dtype
    # kernel
    key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _matmul.dds_cache:
        defines = {
            'TM': 128,
            'TN': block,
            'TK': 16,
            'BLOCK': block,
            'TYPE': dtype,
            'STRIDE_AM': 1 if trans_a else 'lda',
            'STRIDE_AK': 'lda' if trans_a else 1,
            'STRIDE_BN': block if trans_b else 1,
            'STRIDE_BK': 1 if trans_b else block,
            'STRIDE_CM': '1' if trans_c else 'ldc',
            'STRIDE_CN': 'ldc' if trans_c else '1',
            'NAME': 'dds_kernel',
            'DDS': True
        }
        _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines)
    kernel = _matmul.dds_cache[key]
    # output
    CS0 = AS0
    CS1 = AS1
    CS2 = BS2 if trans_c else AS2
    CS3 = AS2 if trans_c else BS2
    locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
           a.stride(2), block, c.stride(2),
           a.stride(0), b.stride(0), c.stride(0),
           a.stride(1), b.stride(1), c.stride(1),
           AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
           grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0])
    return c
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                num_locks, width, packs, bench, time):
    global triton
    if triton is None:
        triton = importlib.import_module('triton')
    # shapes / dtypes
    AS0 = spdims[0]
    AS1 = block * spdims[2 if trans_a else 1]
    AS2 = block * spdims[1 if trans_a else 2]
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    # kernel
    meta = {
        'TM': block,
        'TN': 128,
        'TK': 16,
        'BLOCK': block,
        'TZ': 1,
        'SDD': False,
        'DSD': True,
        'DDS': False
    }
    # output
    CS0 = BS0
    CS1 = BS1
    CS2 = BS3 if trans_c else AS1
    CS3 = AS1 if trans_c else BS3
    locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    grid = lambda meta: [width, triton.cdiv(BS3, meta['TN']), BS0]
    _kernel[grid](a, b, c,
                  a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
                  b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
                  c.stride(0), c.stride(1), c.stride(2), c.stride(3),
                  BS3, AS1, 0, 0, lut, locks, num_locks,
                  num_warps=4,
                  **meta)
    return c
def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask,
            kp_mask_mode, attn_mask_mode,
            spdims, block, lut, num_blocks, maxlut, bench, time):
    apply_scale = False if scale == 1.0 else True
    # handle None rpe
    if rpe is None:
        apply_rpe = False
        stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
        rpe = torch.empty(0, dtype=x.dtype, device=x.device)
    else:
        apply_rpe = True
        stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
    # handle None key_padding_mask
    if key_padding_mask is None:
        apply_kp_mask = False
        stride_zkpm = 0
        key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
    else:
        apply_kp_mask = True
        stride_zkpm = key_padding_mask.stride(0)
    # handle None attention_mask
    if attn_mask is None:
        apply_attn_mask = False
        stride_zattnm = 0
        attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
    else:
        apply_attn_mask = True
        stride_zattnm = attn_mask.stride(0)
    # run kernel
    kernel = _sparse_softmax.make_kernel(fwd_kernels, softmax_fwd, maxlut * block,
                                         x.dtype, block,
                                         apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask,
                                         kp_mask_mode, attn_mask_mode)
    M = x.shape[0]
    grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.d('TM')), M]
    time[0] = kernel(x, scale, lut, rpe, key_padding_mask, attn_mask,
                     num_blocks, maxlut,
                     x.stride(0),
                     stride_zrpe, stride_hrpe, stride_srpe,
                     stride_zkpm, stride_zattnm,
                     grid=grid, bench=bench)
    # save to context
    ctx.mark_dirty(x)
    ctx.save_for_backward(x, lut)
    ctx.spdims = spdims
    ctx.block = block
    ctx.maxlut = maxlut
    ctx.scale = scale
    ctx.apply_scale = apply_scale
    ctx.apply_rpe = apply_rpe
    ctx.apply_kp_mask = apply_kp_mask
    ctx.apply_attn_mask = apply_attn_mask
    ctx.kp_mask_mode = kp_mask_mode
    ctx.attn_mask_mode = attn_mask_mode
    return x
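# Illustrative sketch (hypothetical helper): the convention used above for the
# optional rpe / mask tensors. A None argument is replaced by an empty
# placeholder with a zero leading stride plus a boolean flag, so the kernel is
# always launched with the same argument list and the flag selects the
# compiled variant.
import torch

def _optional_tensor(t, like):
    if t is None:
        return False, torch.empty(0, dtype=like.dtype, device=like.device), 0
    return True, t, t.stride(0)

x = torch.randn(4, 16)
apply_kp_mask, kp_mask, stride_zkpm = _optional_tensor(None, x)
assert apply_kp_mask is False and kp_mask.numel() == 0 and stride_zkpm == 0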
def _call(a, b,
          pad_d, pad_h, pad_w,
          stride_d, stride_h, stride_w,
          upsample_d, upsample_h, upsample_w,
          a_layout, b_layout, c_layout):
    # input shapes
    shape_a = list(triton.shape(a))
    shape_b = list(triton.shape(b))
    dim = len(shape_a) - 2
    # indices
    an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw']
    bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs']
    cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw']
    # extract shapes
    if dim == 2:
        shape_a.insert(ad, 1)
    if dim == 2:
        shape_b.insert(bd, 1)
    # output shape
    shape_c = [0] * 5
    shape_c[cn] = shape_a[an]
    shape_c[ck] = shape_b[bk]
    shape_c[cd] = (shape_a[ad] * upsample_d - shape_b[bd] + 1 + 2 * pad_d + stride_d - 1) // stride_d
    shape_c[ch] = (shape_a[ah] * upsample_h - shape_b[bh] + 1 + 2 * pad_h + stride_h - 1) // stride_h
    shape_c[cw] = (shape_a[aw] * upsample_w - shape_b[bw] + 1 + 2 * pad_w + stride_w - 1) // stride_w
    # strides
    stride_a = _conv._extract_strides(shape_a)
    stride_b = _conv._extract_strides(shape_b)
    stride_c = _conv._extract_strides(shape_c)
    # tiling parameters
    TM = [32]
    TN = [32]
    TK = 8
    # pointer deltas for a
    delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w,
                             bc, bd, bh, bw,
                             ac, ad, ah, aw,
                             stride_a, shape_b, TK)
    delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
    # delta increments for a
    inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32)
    inc_a = ((inc_a + TK) % inc_a.size) - inc_a
    inc_a = triton.fw.torch.from_numpy(inc_a).cuda()
    # allocate output
    if dim == 2:
        shape_c.pop(cd)
    c = triton.empty(shape_c, dtype=a.dtype)
    if dim == 2:
        shape_c.insert(cd, 1)
    # execute kernel
    trans_b = False
    is_wgrad = False
    is_blut = False
    macros = {
        'UPAR': 'stride_h' if is_wgrad else '1',
        'UPAS': '******' if is_wgrad else '1',
        'UPAH': '' if is_wgrad else 'stride_h',
        'UPAW': '' if is_wgrad else 'stride_w',
        'LUT_SIZE': delta_a.shape[-1],
        'TM': TM,
        'TN': TN,
        'TK': TK,
        'A_TYPE': 'float',
        'B_TYPE': 'float'
    }
    MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw]
    MATMUL_N = shape_c[ck]
    MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw]
    _conv.kernel(a, b, c,
                 # matrix multiplication shapes
                 MATMUL_M, MATMUL_N, MATMUL_K,
                 # shapes for a
                 shape_a[ah], shape_a[aw],
                 # shapes for b
                 shape_b[bh], shape_b[bw],
                 # shapes for c
                 shape_c[ch], shape_c[cw], shape_c[cn],
                 # strides for a
                 stride_a[an], stride_a[ac], stride_a[ad + 0], stride_a[ad + 1], stride_a[ad + 2],
                 # strides for b
                 stride_b[bc], stride_b[bd + 0], stride_b[bd + 1], stride_b[bd + 2], stride_b[bk],
                 # strides for c
                 stride_c[cn], stride_c[ck], stride_c[cd], stride_c[cd + 1], stride_c[cd + 2],
                 # padding
                 pad_h, pad_w,
                 # striding
                 stride_h, stride_w,
                 # upsampling
                 upsample_h, upsample_w,
                 0, 0, 0, 0, 0, 0,
                 # look-up table
                 delta_a, inc_a,
                 lambda opt: [triton.cdiv(MATMUL_M, opt.d('TM')),
                              triton.cdiv(MATMUL_N, opt.d('TN'))],
                 **macros)
    return c
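# Illustrative sketch: the output-size formula used for shape_c above, written
# as a standalone helper (hypothetical name). With upsample == 1 it reduces to
# the usual floor((in + 2*pad - kernel) / stride) + 1.
def _conv_out_size(in_size, kernel, pad, stride, upsample=1):
    return (in_size * upsample - kernel + 1 + 2 * pad + stride - 1) // stride

# e.g. a 32x32 input, 3x3 filter, padding 1, stride 2 -> 16x16 output
assert _conv_out_size(32, 3, pad=1, stride=2) == (32 + 2 * 1 - 3) // 2 + 1 == 16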
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                num_locks, width, packs):
    # shapes / dtypes
    AS0 = spdims[0]
    AS1 = block * spdims[2 if trans_a else 1]
    AS2 = block * spdims[1 if trans_a else 2]
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    # kernel
    key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _matmul.dsd_cache:
        TN = [64, 128] if dtype == torch.float32 else [64, 128]
        TK = [8] if dtype == torch.float32 else [16]
        defines = {
            'TM': block,
            'TN': TN,
            'TK': TK,
            'BLOCK': block,
            'TYPE': dtype,
            'STRIDE_AM': 1 if trans_a else block,
            'STRIDE_AK': block if trans_a else 1,
            'STRIDE_BN': 'ldb' if trans_b else '1',
            'STRIDE_BK': '1' if trans_b else 'ldb',
            'STRIDE_CM': '1' if trans_c else 'ldc',
            'STRIDE_CN': 'ldc' if trans_c else '1',
            'NAME': 'dsd_kernel',
            'DSD': True
        }
        _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4])
    kernel = _matmul.dsd_cache[key]
    # output
    CS0 = BS0
    CS1 = BS1
    CS2 = BS3 if trans_c else AS1
    CS3 = AS1 if trans_c else BS3
    locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
           block, b.stride(2), c.stride(2),
           a.stride(0), b.stride(0), c.stride(0),
           a.stride(1), b.stride(1), c.stride(1),
           BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
           grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
    return c