def forward(ctx, x):
    M, N = x.shape
    ldx = N
    dtype = x.dtype
    # allocate the output with the input's dtype so it matches the TYPE define
    y = torch.empty((M, N), dtype=dtype, device=x.device)
    defines = {
        'TYPE': dtype,
        'TM': [32, 64, 128],
        'TN': [32, 64, 128],
    }
    grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
    if _copy.kernel is None:
        _copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4])
    _copy.kernel(x, y, M, N, ldx, grid=grid)
    return y
def make_kernel(device, dtype):
    key = (device, dtype)
    cache = make_kernel.cache
    if key not in cache:
        defines = {'TYPE': dtype}
        cache[key] = triton.kernel(src,
                                   device=device,
                                   defines=defines,
                                   autotune_vals=autotune_configs,
                                   autotune_key=autotune_key)
    return cache[key]
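# The cache above is plain memoization keyed by (device, dtype). A minimal
# self-contained sketch of the same pattern, with a stub standing in for the
# legacy triton.kernel compilation (illustrative only; `src`,
# `autotune_configs`, and `autotune_key` live elsewhere in the module):
def _memoized_make_kernel(device, dtype):
    key = (device, dtype)
    cache = _memoized_make_kernel.cache
    if key not in cache:
        cache[key] = object()  # stands in for the expensive triton.kernel(...) call
    return cache[key]

_memoized_make_kernel.cache = {}

assert _memoized_make_kernel('cuda:0', 'float16') is _memoized_make_kernel('cuda:0', 'float16')
assert _memoized_make_kernel('cuda:0', 'float16') is not _memoized_make_kernel('cuda:0', 'float32')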
def forward(ctx, x, gamma, beta, eps):
    # lazy compilation of kernel
    if _batchnorm.fwd_kernel is None:
        _batchnorm.fwd_kernel = triton.kernel(fwd_src, defines={'TM': 128})
    # shapes
    shape = triton.shape(x)
    dtype = x.dtype
    # allocate outputs
    C, H, W, B = shape[0], shape[1], shape[2], shape[3]
    y = triton.empty(shape, dtype=dtype)
    mean = triton.empty([C], dtype=dtype)
    var = triton.empty([C], dtype=dtype)
    # execute kernels
    _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H * W * B, eps,
                          grid=lambda opt: [1, C])
    # save
    ctx.save_for_backward(x, gamma, beta, mean, var)
    ctx.eps = eps
    return y
class _add(torch.autograd.Function):
    src = """
__global__ void add(float* z, float* x, float* y, int N) {
    int pid = get_program_id(0);
    int offset[TILE] = pid * TILE + 0 ... TILE;
    float* pz[TILE] = z + offset;
    float* px[TILE] = x + offset;
    float* py[TILE] = y + offset;
    bool check[TILE] = offset < N;
    *?(check)pz = *?(check)px + *?(check)py;
}
"""
    kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])

    @staticmethod
    def forward(ctx, x, y):
        z = torch.empty_like(x).cuda()
        N = x.numel()
        grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )
        _add.kernel(z, x, y, N, grid=grid)
        return z
def get_kernel(block, dtype, device):
    key = (block, dtype, device)
    if key not in kernels:
        src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
        defines = {'BLOCK': block, 'TYPE': dtype}
        kernels[key] = triton.kernel(src, device=device, defines=defines)
    return kernels[key]
def forward(cls, ctx, logits, indices):
    n_vocab = logits.shape[-1]
    assert indices.dtype == torch.int64
    assert n_vocab % 128 == 0, "Number of logit options must be divisible by 128."
    if (logits.dtype, n_vocab) not in cls.input_config_to_kernel_fwd:
        cls.input_config_to_kernel_fwd[(logits.dtype, n_vocab)] = triton.kernel(
            cls.fwd_src,
            device=logits.device,
            defines={
                "TILE": n_vocab,
                "TYPE": logits.dtype,
            },
            num_warps=[4],
        )
    kernel_fwd = cls.input_config_to_kernel_fwd[(logits.dtype, n_vocab)]
    result = torch.zeros_like(indices, dtype=logits.dtype).cuda()
    grid = lambda opt: (logits.shape[0], )
    kernel_fwd(logits.data_ptr(), indices.data_ptr(), result.data_ptr(), grid=grid)
    # logits -> neg_logprobs via an in-place modification by kernel_fwd
    ctx.save_for_backward(logits, indices)
    return result
def make_kernel(cache, src, max_k, dtype, block, apply_scale, apply_rpe,
                apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode):
    if max_k >= 32768:
        raise NotImplementedError('Reductions larger than 32768 elements '
                                  'are not yet implemented')
    num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
    pad = num_warps * 32 * 2
    TN = (int(max_k) + pad - 1) // pad * pad
    # just-in-time compile kernel
    key = (block, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask,
           apply_attn_mask, kp_mask_mode, attn_mask_mode)
    if key not in cache:
        defines = {'TM': [1], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block,
                   'INFINITY': {torch.float32: 'F32_INFINITY',
                                torch.float16: 'F16_INFINITY'}[dtype]}
        if apply_scale:
            defines['APPLY_SCALE'] = True
        if apply_rpe:
            defines['APPLY_RPE'] = True
        if apply_kp_mask:
            defines['APPLY_KP_MASK'] = True
        if kp_mask_mode == 'mul':
            defines['KP_MASK_MUL'] = True
        if apply_attn_mask:
            defines['APPLY_ATTN_MASK'] = True
        if attn_mask_mode == 'mul':
            defines['ATTN_MASK_MUL'] = True
        kernel = triton.kernel(src, defines=defines, num_warps=[num_warps])
        cache[key] = kernel
    return cache[key]
def make_kernel(cache, src, max_k, dtype, block, kp_mask_mode, attn_mask_mode):
    # pad tile to cover the entire reduction
    params = {16384: (1, 32768, 16), 8192: (1, 16384, 16), 4096: (1, 8192, 16),
              2048: (1, 4096, 16), 1024: (1, 2048, 16), 512: (1, 1024, 8),
              256: (1, 512, 4), 128: (1, 256, 4)}
    bound = max(128, 2**int(math.log2(max_k - 1)))
    if bound not in params:
        raise NotImplementedError('Reductions larger than 32768 elements '
                                  'are not yet implemented')
    TM, TN, num_warps = params[bound]
    # just-in-time compile kernel
    key = (dtype, TM, TN, num_warps, kp_mask_mode, attn_mask_mode)
    if key not in cache:
        defines = {'TM': [TM], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block,
                   'INFINITY': {torch.float32: 'F32_INFINITY',
                                torch.float16: 'F16_INFINITY'}[dtype]}
        if kp_mask_mode == 'mul':
            defines['KP_MASK_MUL'] = True
        if attn_mask_mode == 'mul':
            defines['ATTN_MASK_MUL'] = True
        kernel = triton.kernel(src, defines=defines, num_warps=[num_warps])
        cache[key] = kernel
    return cache[key]
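# The tile-size arithmetic used by the two make_kernel variants above can be
# checked in plain Python. A small sketch (values chosen for illustration):
import math

def padded_tn(max_k):
    # num_warps heuristic and round-up from the first variant
    num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
    pad = num_warps * 32 * 2
    return (int(max_k) + pad - 1) // pad * pad

assert padded_tn(500) == 512    # num_warps=4, pad=256
assert padded_tn(1000) == 1024  # num_warps=8, pad=512

# the second variant instead snaps to a power-of-two bound and reads a table
bound = max(128, 2**int(math.log2(1000 - 1)))
assert bound == 512             # params[512] -> TM=1, TN=1024, num_warps=8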
def backward(ctx, dy):
    # lazy compilation of kernel
    if _batchnorm.bwd_kernel is None:
        _batchnorm.bwd_kernel = triton.kernel(bwd_src, defines={'TN': 128})
    # retrieve info
    x, gamma, beta, mean, var = ctx.saved_tensors
    eps = ctx.eps
    # allocate result
    dx = triton.empty(triton.shape(x), dtype=x.dtype)
    dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
    dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
    # execute
    C, H, W, B = triton.shape(x)
    _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy, x, gamma, mean, var,
                          H * W * B, eps,
                          grid=lambda opt: [1, C])
    return dx, dgamma, dbeta, None
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                num_locks, width, packs):
    # shapes / dtypes
    AS0 = spdims[0]
    AS1 = block * spdims[2 if trans_a else 1]
    AS2 = block * spdims[1 if trans_a else 2]
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    # kernel
    key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _matmul.dsd_cache:
        defines = {
            'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype,
            'STRIDE_AM': 1 if trans_a else block,
            'STRIDE_AK': block if trans_a else 1,
            'STRIDE_BN': 'ldb' if trans_b else '1',
            'STRIDE_BK': '1' if trans_b else 'ldb',
            'STRIDE_CM': '1' if trans_c else 'ldc',
            'STRIDE_CN': 'ldc' if trans_c else '1',
            'NAME': 'dsd_kernel', 'DSD': True
        }
        _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
    kernel = _matmul.dsd_cache[key]
    # output
    CS0 = BS0
    CS1 = BS1
    CS2 = BS3 if trans_c else AS1
    CS3 = AS1 if trans_c else BS3
    locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block,
           b.stride(2), c.stride(2),
           a.stride(0), b.stride(0), c.stride(0),
           a.stride(1), b.stride(1), c.stride(1),
           BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks,
           grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
    return c
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts,
                num_locks, widths, packs):
    if trans_c:
        a, b = b, a
        trans_a, trans_b = not trans_b, not trans_a
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    AS3 = a.size(2 if trans_a else 3)
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    device = a.device
    is_16_multiple = AS3 % 16 == 0
    is_32_multiple = AS3 % 32 == 0
    is_64_multiple = AS3 % 64 == 0
    if not is_16_multiple:
        raise ValueError('Reduction size for SDD must be a multiple of 16')
    # create kernel
    total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
    c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
    for lut, width, pack in zip(luts, widths, packs):
        num_lock = 1
        key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack,
               is_32_multiple, is_64_multiple)
        if key not in _matmul.sdd_cache:
            defines = {'TM': block * pack, 'TN': block * pack,
                       'TMN': block * block * pack * pack, 'BLOCK': block,
                       'TK': 32, 'TYPE': dtype,
                       'STRIDE_AM': '1' if trans_a else 'lda',
                       'STRIDE_AK': 'lda' if trans_a else '1',
                       'STRIDE_BN': 'ldb' if trans_b else '1',
                       'STRIDE_BK': '1' if trans_b else 'ldb',
                       'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
                       'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
            _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)
        kernel = _matmul.sdd_cache[key]
        # create output
        locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device)
        # maximum grid size is 65535, so the operation might be decomposed
        # into multiple kernel calls
        max_width = 49152
        for off_width in range(0, width, max_width):
            kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
                   a.stride(2), b.stride(2), block,
                   a.stride(0), b.stride(0), c.stride(0),
                   a.stride(1), b.stride(1), c.stride(0),
                   AS2, AS2, AS3, off_width,
                   lut.data_ptr(), locks.data_ptr(), num_lock,
                   grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
    # save for backward pass
    return c
def forward(cls, ctx, logits, indices):
    # make sure we can use triton
    n_vocab = logits.shape[-1]
    assert indices.dtype == torch.int64, "Indices are expected to be of type long."
    # compile a new kernel if needed; otherwise load from a cache
    if (logits.dtype, n_vocab) not in cls.input_config_to_kernel_fwd:
        infinities = {
            torch.float16: "F16_INFINITY",
            torch.float32: "F32_INFINITY",
        }
        cls.input_config_to_kernel_fwd[(logits.dtype, n_vocab)] = triton.kernel(
            cls.fwd_src,
            device=logits.device,
            defines={
                "TILE": make_power_of_two(n_vocab),
                "TYPE": logits.dtype,
                "INFINITY": infinities[logits.dtype],
            },
        )
    kernel_fwd = cls.input_config_to_kernel_fwd[(logits.dtype, n_vocab)]
    # flatten logits and be prepared to restore them to their original shape
    original_logits_shape = logits.shape
    if len(original_logits_shape) > 2:
        logits = logits.reshape((-1, n_vocab))
        indices = indices.reshape((-1, ))
    # run the kernel and assign the result in place
    result = torch.empty_like(indices, dtype=logits.dtype).cuda()
    neg_logprobs = torch.empty_like(logits, dtype=logits.dtype).cuda()
    grid = lambda opt: (logits.shape[0], )
    kernel_fwd(
        logits.data_ptr(),
        neg_logprobs.data_ptr(),
        indices.data_ptr(),
        result.data_ptr(),
        n_vocab,
        grid=grid,
    )
    if len(original_logits_shape) > 2:
        logits = logits.reshape(original_logits_shape)
        indices = indices.reshape(*original_logits_shape[:-1])
    ctx.save_for_backward(neg_logprobs, indices)
    ctx.original_logits_shape = original_logits_shape
    return result
def forward(ctx, x, scale, bias, res):
    if x.dtype not in _relu.fwd_kernel:
        defines = {'TYPE': x.dtype, 'TN': [128]}
        _relu.fwd_kernel[x.dtype] = triton.kernel(_relu.fwd_src,
                                                  defines=defines,
                                                  num_warps=[4])
    kernel = _relu.fwd_kernel[x.dtype]
    # launch kernel
    y = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
    N = x.numel()
    grid = lambda opt: [triton.cdiv(N, opt.d('TN'))]
    kernel(x, y, scale.item(), bias.item(), res, N, grid=grid)
    # update context
    ctx.save_for_backward(x, y)
    ctx.scale = scale
    return y
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                num_locks, width, bench, time):
    if trans_c:
        a, b = b, a
        trans_a, trans_b = not trans_b, not trans_a
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    AS3 = a.size(2 if trans_a else 3)
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    # create kernel
    key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _sparse_matmul.sdd_cache:
        defines = {'TM': block, 'TN': block, 'TK': 16, 'TYPE': dtype,
                   'STRIDE_AM': '1' if trans_a else 'lda',
                   'STRIDE_AK': 'lda' if trans_a else '1',
                   'STRIDE_BN': 'ldb' if trans_b else '1',
                   'STRIDE_BK': '1' if trans_b else 'ldb',
                   'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
                   'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
        _sparse_matmul.sdd_cache[key] = triton.kernel(src, defines=defines,
                                                      num_warps=[1, 2, 4])
    kernel = _sparse_matmul.sdd_cache[key]
    # create output
    locks = _sparse_matmul.get_locks(2 * width * AS0 * num_locks)
    c = torch.empty((AS0, width, block, block), dtype=dtype, device=a.device)
    # maximum grid size is 65535, so the operation might be decomposed
    # into multiple kernel calls
    max_width = 49152
    total = 0 if bench else None
    for off_width in range(0, width, max_width):
        current = kernel(a, b, c,
                         a.stride(2), b.stride(2), block,
                         a.stride(0), b.stride(0), c.stride(0),
                         a.stride(1), b.stride(1), c.stride(0),
                         AS2, AS3, off_width, lut, locks, num_locks,
                         grid=lambda opt: [opt.d('TZ'),
                                           min(max_width, width - off_width),
                                           AS0],
                         bench=bench)
        total = total + current if bench else None
    time[0] = total
    # save for backward pass
    return c
def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                num_locks, width, packs, bench, time):
    # shapes / dtypes
    AS0 = spdims[0]
    AS1 = block * spdims[2 if trans_a else 1]
    AS2 = block * spdims[1 if trans_a else 2]
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    # kernel
    key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _sparse_matmul.dsd_cache:
        TN = [64, 128]
        TK = [8] if dtype == torch.float32 else [16]
        defines = {'TM': block, 'TN': TN, 'TK': TK, 'BLOCK': block, 'TYPE': dtype,
                   'STRIDE_AM': 1 if trans_a else block,
                   'STRIDE_AK': block if trans_a else 1,
                   'STRIDE_BN': 'ldb' if trans_b else '1',
                   'STRIDE_BK': '1' if trans_b else 'ldb',
                   'STRIDE_CM': '1' if trans_c else 'ldc',
                   'STRIDE_CN': 'ldc' if trans_c else '1',
                   'NAME': 'dsd_kernel', 'DSD': True}
        _sparse_matmul.dsd_cache[key] = triton.kernel(src, defines=defines,
                                                      num_warps=[4])
    kernel = _sparse_matmul.dsd_cache[key]
    # output
    CS0 = BS0
    CS1 = BS1
    CS2 = BS3 if trans_c else AS1
    CS3 = AS1 if trans_c else BS3
    locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    time[0] = kernel(a, b, c, block,
                     b.stride(2), c.stride(2),
                     a.stride(0), b.stride(0), c.stride(0),
                     a.stride(1), b.stride(1), c.stride(1),
                     BS3, AS1, 0, 0, lut, locks, num_locks,
                     grid=lambda opt: [width, triton.cdiv(BS3, opt.d('TN')), BS0],
                     bench=bench)
    return c
def backward(ctx, dy):
    # load from context
    x, y = ctx.saved_tensors
    # get kernel
    if x.dtype not in _relu.bwd_kernel:
        defines = {'TYPE': x.dtype, 'TN': [128]}
        _relu.bwd_kernel[x.dtype] = triton.kernel(_relu.bwd_src,
                                                  defines=defines,
                                                  num_warps=[4])
    kernel = _relu.bwd_kernel[x.dtype]
    # allocate output
    dx = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
    dres = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
    dscale = torch.zeros((1, ), device=dy.device, dtype=torch.float32)
    dbias = torch.zeros_like(dscale)
    # launch kernel
    N = x.numel()
    grid = lambda opt: [triton.cdiv(N, opt.d('TN'))]
    kernel(x, y, ctx.scale.item(), dx, dy, dscale, dbias, dres, N, grid=grid)
    return dx, dscale.type(x.dtype), dbias.type(x.dtype), dres
def backward(cls, ctx, dneg_logprobs):
    """We know that d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k],
    so we initialize the gradient as neg_logprobs and can simply
    exponentiate to get p[k], which is most of what we need.
    neg_logprobs will be modified in place to become the gradient we want.
    """
    # load saved tensors and ensure correct types
    neg_logprobs, indices = ctx.saved_tensors
    original_logits_shape = ctx.original_logits_shape
    assert (
        dneg_logprobs.dtype == neg_logprobs.dtype
    ), f"Backward flowing derivatives of type {dneg_logprobs.dtype} != logits type {neg_logprobs.dtype}"
    n_vocab = neg_logprobs.shape[-1]
    # generate or load kernel
    if (neg_logprobs.dtype, n_vocab) not in cls.input_config_to_kernel_bwd:
        cls.input_config_to_kernel_bwd[(neg_logprobs.dtype, n_vocab)] = triton.kernel(
            cls.bwd_src,
            device=neg_logprobs.device,
            defines={
                "TILE": make_power_of_two(n_vocab),
                "TYPE": neg_logprobs.dtype,
            },
        )
    kernel_bwd = cls.input_config_to_kernel_bwd[(neg_logprobs.dtype, n_vocab)]
    grid = lambda opt: (neg_logprobs.shape[0], )
    # neg_logprobs will be modified in place to become our gradient:
    kernel_bwd(
        neg_logprobs.data_ptr(),
        indices.data_ptr(),
        dneg_logprobs.data_ptr(),
        n_vocab,
        grid=grid,
    )
    # reshape results based on shape of original logits passed to forward
    if len(original_logits_shape) > 2:
        neg_logprobs = neg_logprobs.reshape(original_logits_shape)
    return neg_logprobs, torch.zeros_like(indices)
def make_kernel(N, device):
    cache = make_kernel.cache
    # Our kernels are indexed not only by the provided device but also
    # by the rounded number of columns in the input matrix
    BLOCK = next_power_of_2(N)
    # Another trick we can use is to ask the compiler to parallelize each
    # row-normalization more aggressively -- i.e., with more warps -- for
    # vectors that are longer
    # You will see in the next tutorial how to auto-tune this value in a more
    # natural way so you don't have to come up with manual heuristics yourself
    num_warps = 4
    if BLOCK >= 2048:
        num_warps = 8
    if BLOCK >= 4096:
        num_warps = 16
    # Each (BLOCK, num_warps, device) results in a different kernel
    key = (BLOCK, num_warps, device)
    if key not in cache:
        defines = {'BLOCK': BLOCK}
        cache[key] = triton.kernel(_src, device=device, defines=defines,
                                   num_warps=num_warps)
    return cache[key]
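# next_power_of_2 is defined elsewhere; a common bit-twiddling sketch that is
# compatible with the call above (an assumption, not necessarily the exact
# implementation used here):
def next_power_of_2(n):
    """Smallest power of two greater than or equal to n (for n >= 1)."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    return n + 1

assert next_power_of_2(1000) == 1024
assert next_power_of_2(1024) == 1024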
def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut,
                num_locks, width, bench, time):
    # shapes / dtypes
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    AS3 = a.size(2 if trans_a else 3)
    BS0 = spdims[0]
    BS1 = block * spdims[2 if trans_b else 1]
    BS2 = block * spdims[1 if trans_b else 2]
    dtype = a.dtype
    # kernel
    key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
    if key not in _sparse_matmul.dds_cache:
        defines = {'TM': 128, 'TN': block, 'TK': 8, 'TYPE': dtype,
                   'STRIDE_AM': 1 if trans_a else 'lda',
                   'STRIDE_AK': 'lda' if trans_a else 1,
                   'STRIDE_BN': block if trans_b else 1,
                   'STRIDE_BK': 1 if trans_b else block,
                   'STRIDE_CM': '1' if trans_c else 'ldc',
                   'STRIDE_CN': 'ldc' if trans_c else '1',
                   'NAME': 'dds_kernel', 'DDS': True}
        _sparse_matmul.dds_cache[key] = triton.kernel(src, defines=defines,
                                                      num_warps=[4])
    kernel = _sparse_matmul.dds_cache[key]
    # output
    CS0 = AS0
    CS1 = AS1
    CS2 = BS2 if trans_c else AS2
    CS3 = AS2 if trans_c else BS2
    locks = _sparse_matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks)
    c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    time[0] = kernel(a, b, c,
                     a.stride(2), block, c.stride(2),
                     a.stride(0), b.stride(0), c.stride(0),
                     a.stride(1), b.stride(1), c.stride(1),
                     AS2, 0, 0, lut, locks, num_locks,
                     grid=lambda opt: [width, triton.cdiv(AS2, opt.d('TM')), AS0],
                     bench=bench)
    return c
def _call(a, b):
    # create kernel if necessary
    dtype = a.dtype
    if dtype not in _dot.kernel:
        defines = {
            'TYPE': dtype,
            'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
            'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
            'TM': [64, 128], 'TN': [64, 128], 'TK': [8, 16], 'TZ': [1]
        }
        _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
    kernel = _dot.kernel[dtype]
    # allocate output
    M, K = a.shape
    K, N = b.shape
    c = triton.empty([M, N], dtype=dtype)
    # enqueue
    grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
    time = kernel(a, b, c, 1., M, N, K,
                  a.stride(0), b.stride(0), c.stride(0),
                  grid=grid, bench=100)
    # print achieved compute throughput
    print(2 * M * N * K / (time * 1e-6) * 1e-9)
    return c
def backward(cls, ctx, dneg_logprobs):
    """We know that d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k],
    so we initialize the gradient as neg_logprobs and can simply
    exponentiate to get p[k], which is most of what we need.
    neg_logprobs will be modified in place to become the gradient we want.
    """
    neg_logprobs, indices = ctx.saved_tensors
    assert indices.dtype == torch.int64
    assert (
        dneg_logprobs.dtype == neg_logprobs.dtype
    ), f"Backward flowing derivatives of type {dneg_logprobs.dtype} != logits type {neg_logprobs.dtype}"
    n_vocab = neg_logprobs.shape[-1]
    N = neg_logprobs.numel()
    if (neg_logprobs.dtype, n_vocab) not in cls.input_config_to_kernel_bwd:
        cls.input_config_to_kernel_bwd[(neg_logprobs.dtype, n_vocab)] = triton.kernel(
            cls.bwd_src,
            device=neg_logprobs.device,
            defines={
                "TILE": n_vocab,
                "TYPE": neg_logprobs.dtype,
            },
            num_warps=[4],
        )
    kernel_bwd = cls.input_config_to_kernel_bwd[(neg_logprobs.dtype, n_vocab)]
    grid = lambda opt: (triton.cdiv(N, opt.TILE), )
    # neg_logprobs will be modified in place to become our gradient:
    kernel_bwd(
        neg_logprobs.data_ptr(),
        indices.data_ptr(),
        dneg_logprobs.data_ptr(),
        grid=grid,
    )
    return neg_logprobs, torch.zeros_like(indices)
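# The identity in the docstring, d(-log p[i])/dlogit[k] = p[k] - 1[i == k],
# is easy to verify against PyTorch autograd with a small standalone check:
import torch

logits = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
idx = torch.randint(0, 8, (4, ))

loss = -torch.log_softmax(logits, dim=-1)[torch.arange(4), idx].sum()
loss.backward()

p = torch.softmax(logits.detach(), dim=-1)
expected = p.clone()
expected[torch.arange(4), idx] -= 1.0  # p[k] - 1 at the target index
assert torch.allclose(logits.grad, expected)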
def forward(ctx, x, running_mean, running_var, gamma, beta,
            training, momentum, eps):
    N, C, H, W = x.shape
    # lazy compilation of kernel
    key = (training, x.dtype)
    if key not in _batchnorm.fwd_kernel:
        defines = {'TM': 256, 'TYPE': x.dtype}
        if training:
            defines['TRAINING'] = True
        _batchnorm.fwd_kernel[key] = triton.kernel(_batchnorm.fwd_src,
                                                   defines=defines,
                                                   num_warps=[4])
    kernel = _batchnorm.fwd_kernel[key]
    # allocate outputs
    y = torch.empty_strided(x.shape, x.stride(), layout=x.layout,
                            dtype=x.dtype, device=x.device)
    mean = torch.empty(C, dtype=torch.float32, device=x.device)
    var = torch.empty(C, dtype=torch.float32, device=x.device)
    # execute kernels
    grid = lambda opt: [C]
    kernel(y, mean, var, running_mean, running_var, x, gamma, beta,
           H * W * N, momentum, eps, grid=grid)
    # save
    ctx.save_for_backward(x, gamma, beta, mean, var)
    ctx.eps = eps
    return y
def make_kernel(device, dtype, n_cols, cache, name):
    rounded = next_power_of_2(n_cols)
    div = largest_pow2_divisor(n_cols)
    key = (dtype, rounded, div)
    if key not in cache:
        fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
        src = triton.read(fname, kernel_names=[name])
        infinities = {
            torch.float16: "F16_INFINITY",
            torch.float32: "F32_INFINITY",
        }
        defines = {
            "TILE": rounded,
            "TYPE": dtype,
            "INFINITY": infinities[dtype],
            "N_COLS_MULT": div,
        }
        cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4)
    return cache[key]
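# largest_pow2_divisor is defined elsewhere; one standard sketch, using the
# fact that n & -n isolates the lowest set bit (an assumption -- the real
# helper may additionally cap the result):
def largest_pow2_divisor(n):
    """Largest power of two that divides n (for n >= 1)."""
    return n & -n

assert largest_pow2_divisor(96) == 32
assert largest_pow2_divisor(100) == 4
assert largest_pow2_divisor(1) == 1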
def backward(ctx, dy):
    # lazy compilation of kernel
    key = (dy.dtype, )
    if key not in _batchnorm.bwd_kernel:
        _batchnorm.bwd_kernel[key] = triton.kernel(_batchnorm.bwd_src,
                                                   defines={'TM': 256,
                                                            'TYPE': dy.dtype},
                                                   num_warps=[4])
    kernel = _batchnorm.bwd_kernel[key]
    # retrieve info
    x, gamma, beta, mean, var = ctx.saved_tensors
    eps = ctx.eps
    # allocate result
    dx = torch.empty_strided(x.shape, x.stride(), layout=x.layout,
                             dtype=x.dtype, device=x.device)
    dgamma = torch.empty_like(gamma)
    dbeta = torch.empty_like(beta)
    # execute
    N, C, H, W = x.shape
    kernel(dx, dgamma, dbeta, dy, x, gamma, mean, var, H * W * N, eps,
           grid=lambda opt: [C])
    return dx, None, None, dgamma, dbeta, None, None, None
def forward(ctx, a, b, pad, stride):
    # create kernel if necessary
    dtype = a.dtype
    device = a.device
    # shapes
    Z, CI, H, W = a.shape
    _, R, S, CO = b.shape
    P = (H + 2 * pad[0] - R) // stride[0] + 1
    Q = (W + 2 * pad[1] - S) // stride[1] + 1
    # compile kernel (cache keyed consistently by (dtype, device))
    if (dtype, device) not in _conv.kernel:
        TK = 16
        defines = {
            'TYPE': dtype,
            'TM': [32, 64, 128],
            'TN': [32, 64, 128],
            'TK': [TK],
            'TZ': [1],
            'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
        }
        # build the look-up table of pointer increments along the reduction axis
        idx = torch.arange(CI * R * S)
        ci, r, s = _conv.unpack(idx, CI, R, S)
        nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
        delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + \
                (ns - s) * a.stride(3)
        delta = delta.type(torch.int32).cuda()
        _conv.kernel[(dtype, device)] = (delta,
                                         triton.kernel(_conv.src,
                                                       device=device,
                                                       num_warps=[4],
                                                       defines=defines))
    delta, kernel = _conv.kernel[(dtype, device)]
    # allocate output
    c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
    # enqueue
    kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 1.,
           Z * P * Q, CO, CI * R * S,
           pad[0], pad[1], stride[0], stride[1], delta.data_ptr(),
           a.stride(0), a.stride(1), a.stride(2), a.stride(3),
           b.stride(0), b.stride(1), b.stride(2), b.stride(3),
           c.stride(0), c.stride(1), c.stride(2), c.stride(3),
           grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM),
                             triton.cdiv(CO, opt.TN)])
    return c
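# _conv.unpack is defined elsewhere; given how `delta` is assembled above, a
# plausible sketch unpacks a flat index over (CI, R, S) with div/mod
# (hypothetical helper, ci-major layout assumed):
import torch

def unpack(idx, CI, R, S):
    # flat index -> (ci, r, s) coordinates
    s = idx % S
    cr = idx // S
    r = cr % R
    ci = cr // R
    return ci, r, s

idx = torch.arange(3 * 2 * 2)  # CI=3, R=2, S=2
ci, r, s = unpack(idx, 3, 2, 2)
assert (ci * (2 * 2) + r * 2 + s == idx).all()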
class _dot(triton.function):
    src = """
void dot(TYPE * A __noalias __readonly __aligned(16),
         TYPE * B __noalias __readonly __aligned(16),
         TYPE * C,
         float alpha,
         int M, int N, int K,
         int lda __multipleof(8),
         int ldb __multipleof(8),
         int ldc) {
    // prologue
    int ridx = get_program_id(0);
    int ridy = get_program_id(1);
    int rm[TM] = ridx * TM + 0 ... TM;
    int rn[TN] = ridy * TN + 0 ... TN;
    int rk[TK] = 0 ... TK;
    // pointers to operands
    TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
    TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
    // prefetches operands
    bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
    bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
    TYPE a[SHAPE_A] = checka ? *pa : 0;
    TYPE b[SHAPE_B] = checkb ? *pb : 0;
    // reduction loop
    float c[TM, TN] = 0;
    for(int k = K; k > 0; k -= TK){
        c += USE_A @ USE_B;
        bool checka[SHAPE_A] = k > TK;
        bool checkb[SHAPE_B] = k > TK;
        pa += TK * STRIDE_AK;
        pb += TK * STRIDE_BK;
        a = *?(checka)pa;
        b = *?(checkb)pb;
    }
    //c = c * alpha;
    // epilogue
    int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
    int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
    TYPE* pc[TM, TN] = C + rxm[:, newaxis] * ldc + rxn[newaxis, :];
    bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
    *?(checkc)pc = (TYPE[TM, TN])c;
}
"""
    kernel = triton.kernel(src, ['C'])

    @staticmethod
    def _call(a, b, transpose_a, transpose_b, bench):
        # extract shapes
        shape_a = triton.shape(a)
        shape_b = triton.shape(b)
        M, Ka = shape_a[0], shape_a[1]
        Kb, N = shape_b[0], shape_b[1]
        # transpose shapes
        if transpose_a:
            M, Ka = Ka, M
        if transpose_b:
            Kb, N = N, Kb
        # contiguous dimensions
        lda = M if transpose_a else Ka
        ldb = Kb if transpose_b else N
        ldc = N
        # data-type
        dtype = a.dtype
        # allocate output
        c = triton.empty([M, N], dtype=dtype)
        # compute
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
        # macros -- not necessary but makes kernel source-code simpler
        macros = {# handle A transposition
                  'USE_A': '^a' if transpose_a else 'a',
                  'STRIDE_AK': 'lda' if transpose_a else '1',
                  'STRIDE_AM': '1' if transpose_a else 'lda',
                  'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
                  'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
                  'SHAPE_A': 'TK, TM' if transpose_a else 'TM, TK',
                  # handle B transposition
                  'USE_B': '^b' if transpose_b else 'b',
                  'STRIDE_BK': '1' if transpose_b else 'ldb',
                  'STRIDE_BN': 'ldb' if transpose_b else '1',
                  'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
                  'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
                  'SHAPE_B': 'TN, TK' if transpose_b else 'TK, TN'}
        _dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc, grid,
                    bench=bench, AT=transpose_a, BT=transpose_b, TYPE=dtype,
                    TM=[64], TN=[128], TK=[8], **macros)
        return c

    @staticmethod
    def forward(ctx, a, b, transpose_a=False, transpose_b=False, bench=0):
        ctx.save_for_backward(a, b)
        ctx.t_a = transpose_a
        ctx.t_b = transpose_b
        ctx.bench = bench
        return _dot._call(a, b, transpose_a, transpose_b, bench)

    @staticmethod
    def backward(ctx, dy):
        a, b = ctx.saved_tensors
        t_a, t_b = ctx.t_a, ctx.t_b
        bench = ctx.bench
        if not t_a and not t_b:
            da = _dot._call(dy, b, False, True, bench)
            db = _dot._call(a, dy, True, False, bench)
        elif not t_a and t_b:
            da = _dot._call(dy, b, False, False, bench)
            db = _dot._call(dy, a, True, False, bench)
        elif t_a and not t_b:
            da = _dot._call(b, dy, False, True, bench)
            db = _dot._call(a, dy, False, False, bench)
        elif t_a and t_b:
            da = _dot._call(b, dy, True, True, bench)
            db = _dot._call(dy, a, True, True, bench)
        else:
            assert False
        return da, db, None, None, None
def _call(a, b):
    dtype = a.dtype
    device = a.device
    # allocate output
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), dtype=dtype, device=device)
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # kernel hash
    is_a_row = a.stride(1) == 1
    is_b_row = b.stride(1) == 1
    lda = a.stride(0) if is_a_row else a.stride(1)
    ldb = b.stride(0) if is_b_row else b.stride(1)
    ldc = c.stride(0)
    lda_pow2_div = _matmul.largest_pow2_divisor(lda)
    ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
    ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
    is_tk_div_k = K % 64 == 0
    key = (
        device,
        dtype,
        is_a_row,
        is_b_row,
        lda_pow2_div,
        ldb_pow2_div,
        ldc_pow2_div,
        is_tk_div_k,
    )
    if key not in _matmul._kernels:
        defines = {
            "TYPE": dtype,
            "STRIDE_AM": "lda" if is_a_row else "1",
            "STRIDE_AK": "1" if is_a_row else "lda",
            "STRIDE_BK": "ldb" if is_b_row else "1",
            "STRIDE_BN": "1" if is_b_row else "ldb",
            "LDA_POW2_DIV": lda_pow2_div,
            "LDB_POW2_DIV": ldb_pow2_div,
            "LDC_POW2_DIV": ldc_pow2_div,
            "IS_TK_DIV_K": int(is_tk_div_k),
        }
        _matmul._kernels[key] = triton.kernel(
            _matmul.src,
            device,
            defines=defines,
            autotune_vals=_matmul._CONFIGS,
            autotune_key=["M", "N", "K"],
        )
    kernel = _matmul._kernels[key]
    # locks for split-k
    if device not in _matmul._locks:
        _matmul._locks[device] = torch.zeros(1024 * 1024,
                                             dtype=torch.int32,
                                             device=device)
    locks = _matmul._locks[device]
    # enqueue
    alpha = 1.0
    args = [
        a.data_ptr(), b.data_ptr(), c.data_ptr(),
        alpha, M, N, K, lda, ldb, ldc,
        locks.data_ptr(),
    ]
    grid = lambda opt: [
        triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
        1,
        opt.SPLITK,
    ]
    kernel(*args, grid=grid)
    return c
def make_kernel(name, dtype, mask, expr_a, expr_b, expr_c,
                axes_m, axes_n, axes_k, axes_b,
                multipleof_a, multipleof_b, multipleof_c,
                stride_a_last, stride_b_last, stride_c_last,
                lut_mode_a, lut_mode_b,
                delta_a, delta_b,
                subscripted, varnames):
    use_lut_a = True
    use_lut_b = True

    src = ""
    if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
        src += f"""
char __constant__* AD = calloc({4*len(delta_a)});"""
    if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
        src += f"""
char __constant__* BD = calloc({4*len(delta_b)});"""

    src += f"""
__global__ void {name}(
              TYPE * A __noalias __readonly __aligned(16)
            , TYPE * B __noalias __readonly __aligned(16)
            , TYPE * C
            , int * locks
            , float alpha
            , int matmul_m, int matmul_n, int matmul_k __multipleof(16)
            , int div_m
            """
    for dim in [axes_m, axes_n, axes_k, axes_b]:
        for d in dim:
            src += f", int dim_{d}"
    src += "\n            "
    for dim, name, mult in zip([expr_a, expr_b, expr_c],
                               ['a', 'b', 'c'],
                               [multipleof_a, multipleof_b, multipleof_c]):
        for d in range(len(dim) - 1):
            attr = f'__multipleof({mult})'
            src += f", int stride_{name}_{d} {attr}"
        src += "\n            "
    if lut_mode_a == _einsum.LUT_MODE.SCALAR:
        src += f", int stride_a_inner __multipleof({multipleof_a})"
        src += f", int rem_delta_a __multipleof({multipleof_a})"
    elif lut_mode_a == _einsum.LUT_MODE.DRAM:
        src += ", int* AD __noalias __readonly __aligned(16)"
    src += "\n            "
    if lut_mode_b == _einsum.LUT_MODE.SCALAR:
        src += f", int stride_b_inner __multipleof({multipleof_b})"
        src += f", int rem_delta_b __multipleof({multipleof_b})"
    elif lut_mode_b == _einsum.LUT_MODE.DRAM:
        src += ", int* BD"
    src += "\n"
    for ptr in subscripted:
        src += f", int* {ptr}"
    for name in varnames:
        src += f", int {name}"
    src += """) {

    // re-order outer program ids
    int grid_m = (matmul_m + TM - 1) / TM;
    int grid_n = (matmul_n + TN - 1) / TN;
    int pid_mn = get_program_id(0) / div_m;
    int pid_n = pid_mn % grid_n;
    int pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
    // get batch program id
    int pid_b = get_program_id(1);
#if TZ == 1
    int off_k = 0;
#else
    // get reduction sub-group program id
    int pid_z = get_program_id(2);
    int grid_z = get_num_programs(2);
    int div_z = matmul_k / TZ;
    int rem_z = matmul_k % TZ;
    int off_k = pid_z * div_z;
    matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif
    int rem_k = matmul_k % TK;
    // create ranges
"""
    rk = 'r{}'.format(''.join(map(str, axes_k)))
    for axes, tile, off in zip([axes_m, axes_n, axes_b, axes_k],
                               ['TM', 'TN', 'TB', 'TK'],
                               ['pid_m*TM', 'pid_n*TN', 'pid_b*TB', 'off_k']):
        currs = ''.join(map(str, axes))
        if axes:
            src += f"    int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
            src += _einsum.unpack_cc(tile, axes, 'r', False)
    src += """
    // initialize pointers to A
    int offa[TM, TK, TB] = """
    for i, sym in enumerate(expr_a):
        ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
        stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
        if i > 0:
            src += ' + '
        src += f"({ccode}) * {stride}\n                            "
    src += ';'
    src += """
    TYPE *pa[TM, TK, TB] = A + offa;"""
    if use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
        spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
        cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
        src += f"""
    // initialize pointers to A look-up table
    int offadelta[TK] = off_k + 0 ... TK;
    int {spec} *padelta[TK] = {cast}AD + offadelta;
    int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""
    src += """
    // initialize pointers to B
    int offb[TK, TN, TB] = """
    for i, sym in enumerate(expr_b):
        ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
        stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
        if i > 0:
            src += ' + '
        src += f"({ccode}) * {stride}\n                            "
    src += ';'
    src += """
    TYPE *pb[TK, TN, TB] = B + offb;"""
    if use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
        spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
        cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
        src += f"""
    // initialize pointers to B look-up table
    int offbdelta[TK] = off_k + 0 ... TK;
    int *pbdelta[TK] = BD + offbdelta;"""
    src += f"""
    // prefetch
    int prefetch_k = select(rem_k > 0, rem_k, TK);
    bool checkm[TM] = r""" + ''.join(map(str, axes_m)) + f""" < matmul_m;
    bool checkn[TN] = r""" + ''.join(map(str, axes_n)) + f""" < matmul_n;
    bool checkk[TK] = {rk} < prefetch_k;
    bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
    TYPE a[TM, TK, TB] = checka ? *pa : 0;
    TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""
    if lut_mode_a == _einsum.LUT_MODE.SCALAR:
        src += """
    pa += rem_delta_a;"""
    else:
        src += """
    pa += incda;
    padelta += TK;
    incda = (*padelta)[newaxis, :, newaxis];"""
    if lut_mode_b == _einsum.LUT_MODE.SCALAR:
        src += """
    pb += rem_delta_b;"""
    else:
        src += """
    pb += (*pbdelta)[:, newaxis, newaxis];
    pbdelta += TK;"""
    src += f"""
    // accumulate
    float acc[TM, TN, TB] = 0;
    for(int k = matmul_k; k > 0; k -= TK) {{
        acc += a @ b;
#ifdef MASK
        uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
        acc = bitcast<float[TM, TN, TB]>(bits & MASK);
#endif
        checkk = k > TK;
        checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
        checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
        a = *?(checka)pa;
        b = *?(checkb)pb;"""
    if lut_mode_a == _einsum.LUT_MODE.SCALAR:
        src += """
        pa += stride_a_inner;"""
    else:
        src += """
        pa += incda;
        padelta += TK;
        incda = (*padelta)[newaxis, :, newaxis];"""
    if lut_mode_b == _einsum.LUT_MODE.SCALAR:
        src += """
        pb += stride_b_inner;"""
    else:
        src += """
        pb += (*pbdelta)[:, newaxis, newaxis];
        pbdelta += TK;"""
    src += f"""
    }}
    TYPE c[TM, TN, TB] = acc;

    // re-materialize ranges
    pid_mn = get_program_id(0) / div_m;
    pid_n = pid_mn % grid_n;
    pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
"""
    for axes, tile, off in zip([axes_m, axes_n, axes_b],
                               ['TM', 'TN', 'TB'],
                               ['pid_m*TM', 'pid_n*TN', 'pid_b*TB']):
        currs = ''.join(map(str, axes))
        if axes:
            src += f"    r{currs} = {off} + 0 ... {tile};\n"
            src += _einsum.unpack_cc(tile, axes, 'r', True)
    src += """
    // initialize pointers to C
    int offc[TM, TN, TB] = """
    for i, sym in enumerate(expr_c):
        stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
        ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
        if i > 0:
            src += ' + '
        src += f"({ccode}) * {stride}\n                            "
    src += ';'
    src += """
    TYPE *pc[TM, TN, TB] = C + offc;

    // bounds-checking
    checkm = r""" + ''.join(map(str, axes_m)) + """ < matmul_m;
    checkn = r""" + ''.join(map(str, axes_n)) + """ < matmul_n;
    bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];

    // write back
#if TZ == 1
    *?(checkc)pc = c;
#else
    int *plock = locks + pid_mn + pid_b * get_num_programs(0);
    int *pcount = plock + 1024*1024;
    // spin
    for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
        *?(checkc)pc = c;
    else
        *?(checkc)pc = c + *?(checkc)pc;
    atomic_xchg(pcount, (count + 1) % (grid_z));
    atomic_xchg(plock, 0);
#endif
}
"""
    # compilation options
    TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
    TK = 16 if dtype == torch.float16 else 8
    defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
    if mask is not None:
        defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
    # create kernel
    ret = triton.kernel(src, defines=defines)
    # set constant
    if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
        ret.set_constant('AD', delta_a)
    if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
        ret.set_constant('BD', delta_b)
    return ret
def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts,
                num_locks, widths, packs, bench, time):
    if trans_c:
        a, b = b, a
        trans_a, trans_b = not trans_b, not trans_a
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    AS3 = a.size(2 if trans_a else 3)
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    if dtype == torch.float16 and AS3 % 64 > 0:
        raise ValueError('Reduction size for SDD must be a multiple of 64 in FLOAT16')
    if dtype == torch.float32 and AS3 % 16 > 0:
        raise ValueError('Reduction size for SDD must be a multiple of 16 in FLOAT32')
    # create kernel
    total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
    c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=a.device)
    for lut, width, pack in zip(luts, widths, packs):
        num_lock = 1
        key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack)
        if key not in _sparse_matmul.sdd_cache:
            TK = {torch.float32: [8, 16], torch.float16: [16, 32, 64]}[dtype]
            defines = {
                'TM': block * pack, 'TN': block * pack,
                'TMN': block * block * pack * pack,
                'BLOCK': block, 'TK': TK, 'TYPE': dtype,
                'STRIDE_AM': '1' if trans_a else 'lda',
                'STRIDE_AK': 'lda' if trans_a else '1',
                'STRIDE_BN': 'ldb' if trans_b else '1',
                'STRIDE_BK': '1' if trans_b else 'ldb',
                'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
                'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'
            }
            _sparse_matmul.sdd_cache[key] = triton.kernel(src, defines=defines,
                                                          num_warps=[1, 2, 4])
        kernel = _sparse_matmul.sdd_cache[key]
        # create output
        locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock, a.device)
        # maximum grid size is 65535, so the operation might be decomposed
        # into multiple kernel calls
        max_width = 49152
        total = 0 if bench else None
        for off_width in range(0, width, max_width):
            current = kernel(a, b, c,
                             a.stride(2), b.stride(2), block,
                             a.stride(0), b.stride(0), c.stride(0),
                             a.stride(1), b.stride(1), c.stride(0),
                             AS2, AS2, AS3, off_width, lut, locks, num_lock,
                             grid=lambda opt: [opt.d('TZ'),
                                               min(max_width, width - off_width),
                                               AS0],
                             bench=bench)
            total = total + current if bench else None
        time[0] = total
    # save for backward pass
    return c
class _batchnorm(triton.function):

    fwd_src = """
void fwdbatchnorm(float *Y, float *M, float *V,
                  float *X, float *G, float *B,
                  int N, float eps) {
    // pointers
    int c = get_program_id(1);
    int rm[TM] = 0 ... TM;
    float *px[TM] = X + rm + c*N;
    float* py[TM] = Y + rm + c*N;

    // compute mean
    float accm[TM] = 0;
    for(int i = 0; i < N; i = i + TM)
        accm = accm + *(px + i);
    float mean = (float)accm[+] / N;
    *(M + c) = mean;

    // compute variance
    float accv[TM] = 0;
    for(int i = 0; i < N; i = i + TM){
        float x[TM] = *(px + i);
        x = x - mean;
        accv = accv + x*x;
    }
    float var = (float)accv[+] / N;
    *(V + c) = var;

    // Normalize batch
    float gamma = *(G + c);
    float beta = *(B + c);
    float rstdg = 1 / sqrtf(var + eps) * gamma;
    for(int i = 0; i < N; i = i + TM){
        float x[TM] = *(px + i);
        float y[TM] = (x - mean)*rstdg + beta;
        *(py + i) = y;
    }
}
"""
    fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])

    bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
                  float *DY, float *X, float *G,
                  float *M, float *V,
                  int N, float epsilon) {
    // pointers
    int c = get_program_id(1);
    int rx[TM] = 0 ... TM;
    int offset = c*N;
    float* px[TM] = X + rx + offset;
    float* pdy[TM] = DY + rx + offset;
    float* pdx[TM] = DX + rx + offset;

    // fetch statistics
    float gamma = *(G + c);
    float mean = *(M + c);
    float var = *(V + c);
    float rstd = 1 / sqrtf(var + epsilon);

    // compute dgamma and dbeta
    float acc_dg[TM] = 0;
    float acc_db[TM] = 0;
    for(int i = 0; i < N; i = i + TM){
        float x[TM] = *(px + i);
        float dy[TM] = *(pdy + i);
        acc_dg += dy*(x - mean)*rstd;
        acc_db += dy;
    }
    float dg = acc_dg[+];
    float db = acc_db[+];
    *(DG + c) = dg;
    *(DB + c) = db;

    // compute dx
    for(int i = 0; i < N; i = i + TM){
        float x[TM] = *(px + i);
        float dy[TM] = *(pdy + i);
        float xhat[TM] = (x - mean) * rstd;
        float xtmp[TM] = (xhat * dg + db) / N;
        float dx[TM] = (dy - xtmp) * rstd * gamma;
        *(pdx + i) = dx;
    }
}
"""
    bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        shape = triton.shape(x)
        dtype = x.dtype
        # allocate outputs
        C, H, W, B = shape[0], shape[1], shape[2], shape[3]
        y = triton.empty(shape, dtype=dtype)
        mean = triton.empty([C], dtype=dtype)
        var = triton.empty([C], dtype=dtype)
        # execute kernels
        _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
                              grid=lambda opt: [1, C],
                              defines={'TM': 128})
        # save
        ctx.save_for_backward(x, gamma, beta, mean, var)
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        # retrieve info
        x, gamma, beta, mean, var = ctx.saved_tensors
        eps = ctx.eps
        # allocate result
        dx = triton.empty(triton.shape(x), dtype=x.dtype)
        dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
        dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
        # execute
        C, H, W, B = triton.shape(x)
        _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy, x, gamma, mean, var,
                              H*W*B, eps,
                              grid=lambda opt: [1, C],
                              defines={'TM': 128})
        return dx, dgamma, dbeta, None
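# The dx formula in bwdbatchnorm above can be sanity-checked against PyTorch
# autograd for a single channel; a small float64 sketch:
import torch

N, eps = 64, 1e-5
x = torch.randn(N, dtype=torch.float64, requires_grad=True)
gamma = torch.tensor(1.5, dtype=torch.float64)
beta = torch.tensor(0.2, dtype=torch.float64)

mean, var = x.mean(), x.var(unbiased=False)
y = (x - mean) / torch.sqrt(var + eps) * gamma + beta
dy = torch.randn(N, dtype=torch.float64)
y.backward(dy)

with torch.no_grad():
    rstd = 1 / torch.sqrt(var + eps)
    xhat = (x - mean) * rstd
    dg, db = (dy * xhat).sum(), dy.sum()
    dx = (dy - (xhat * dg + db) / N) * rstd * gamma  # mirrors the kernel
assert torch.allclose(x.grad, dx)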