def get_kernel(block, dtype, device):
    """Return the softmax kernel specialized for ``block``, ``dtype`` and ``device``.

    The kernel is compiled from ``softmax.c`` (located next to this file) the
    first time a given ``(block, dtype, device)`` triple is requested; the
    compiled object is stored in the global ``kernels`` mapping and returned
    directly on later calls.
    """
    cache_key = (block, dtype, device)
    if cache_key in kernels:
        return kernels[cache_key]
    here = os.path.dirname(__file__)
    source = triton.read(os.path.join(here, 'softmax.c'))
    kernels[cache_key] = triton.kernel(
        source,
        device=device,
        defines={'BLOCK': block, 'TYPE': dtype},
    )
    return kernels[cache_key]
def make_kernel(device, dtype, n_cols, cache, name):
    """Return the cross-entropy kernel ``name`` specialized for ``n_cols`` columns.

    The source is read from ``cross_entropy.c`` (next to this file) and
    compiled with ``TILE = next_power_of_2(n_cols)`` and
    ``N_COLS_MULT = largest_pow2_divisor(n_cols)``. Compiled kernels are
    memoized in ``cache``. The key does not include ``name``, so callers are
    presumably expected to use a separate cache per kernel name — confirm at
    the call sites.

    Args:
        device: device the kernel is compiled for.
        dtype: element type; must be ``torch.float16`` or ``torch.float32``.
        n_cols: number of columns per row.
        cache: dict used to memoize compiled kernels.
        name: kernel name to extract from ``cross_entropy.c``.

    Returns:
        The compiled Triton kernel.

    Raises:
        KeyError: on a cache miss when ``dtype`` has no infinity constant.
    """
    rounded = next_power_of_2(n_cols)
    div = largest_pow2_divisor(n_cols)
    # The kernel is compiled for a specific device, so the device must be part
    # of the key; otherwise a kernel built for one GPU would be reused for
    # tensors living on another.
    key = (device, dtype, rounded, div)
    if key not in cache:
        fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
        src = triton.read(fname, kernel_names=[name])
        # Name of the C-side infinity constant for each supported dtype.
        infinities = {
            torch.float16: "F16_INFINITY",
            torch.float32: "F32_INFINITY",
        }
        defines = {
            "TILE": rounded,
            "TYPE": dtype,
            "INFINITY": infinities[dtype],
            "N_COLS_MULT": div,
        }
        cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=4)
    return cache[key]
class _conv(torch.autograd.Function):
    """Convolution forward pass implemented by the Triton kernel in ``conv.c``.

    Only ``forward`` is defined in this class; no ``backward`` is provided here.
    """

    src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
    # Cache of compiled kernels: key -> (delta tensor, compiled kernel).
    kernel = dict()

    @staticmethod
    def unpack(IDX, CI, R, S):
        """Split flat reduction indices ``IDX`` into ``(ci, r, s)`` coordinates.

        ``IDX`` enumerates a ``(CI, R, S)`` grid in row-major order, i.e.
        ``IDX == (ci * R + r) * S + s``. ``CI`` is accepted for symmetry but
        is not needed by the computation.
        """
        s = IDX % S
        cr = IDX // S
        r = cr % R
        ci = cr // R
        return ci, r, s

    @staticmethod
    def forward(ctx, a, b, pad, stride):
        """Convolve ``a`` with filter ``b``.

        Args:
            a: input of shape ``(Z, CI, H, W)``.
            b: filter of shape ``(_, R, S, CO)``; the leading dimension is not
                read here (presumably CI — confirm).
            pad: ``(pad_h, pad_w)``.
            stride: ``(stride_h, stride_w)``.

        Returns:
            A new tensor of shape ``(Z, CO, P, Q)`` with ``a``'s dtype and device.
        """
        dtype = a.dtype
        device = a.device
        # shapes
        Z, CI, H, W = a.shape
        _, R, S, CO = b.shape
        P = (H + 2 * pad[0] - R) // stride[0] + 1
        Q = (W + 2 * pad[1] - S) // stride[1] + 1
        # Everything baked into the compiled kernel (HH, WW, PP, QQ, RR, SS)
        # or into the delta table (CI, R, S and a's strides) must be part of
        # the cache key, and the same key must be used for lookup and storage.
        key = (dtype, device, CI, H, W, R, S, P, Q, tuple(a.stride()[1:]))
        if key not in _conv.kernel:
            TK = 16
            defines = {
                'TYPE': dtype,
                'TM': [32, 64, 128],
                'TN': [32, 64, 128],
                'TK': [TK],
                'TZ': [1],
                'HH': H,
                'WW': W,
                'PP': P,
                'QQ': Q,
                'SS': S,
                'RR': R,
            }
            # delta[i] is the offset (in elements of `a`) between the input
            # position for reduction index i and the one for index i + TK.
            idx = torch.arange(CI * R * S)
            ci, r, s = _conv.unpack(idx, CI, R, S)
            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
            delta = ((nci - ci) * a.stride(1)
                     + (nr - r) * a.stride(2)
                     + (ns - s) * a.stride(3))
            # Place the table on the input's device rather than the current
            # CUDA device.
            delta = delta.to(device=device, dtype=torch.int32)
            _conv.kernel[key] = (delta,
                                 triton.kernel(_conv.src, device=device,
                                               num_warps=[4], defines=defines))
        delta, kernel = _conv.kernel[key]
        # allocate output
        c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
        # enqueue
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(),
               1.,
               Z * P * Q, CO, CI * R * S,
               pad[0], pad[1],
               stride[0], stride[1],
               delta.data_ptr(),
               a.stride(0), a.stride(1), a.stride(2), a.stride(3),
               b.stride(0), b.stride(1), b.stride(2), b.stride(3),
               c.stride(0), c.stride(1), c.stride(2), c.stride(3),
               grid=lambda opt: [triton.cdiv(Z * P * Q, opt.TM),
                                 triton.cdiv(CO, opt.TN)])
        return c
class _softmax_xent_loss(torch.autograd.Function):
    """Softmax cross-entropy loss backed by Triton kernels.

    The forward kernel writes negative log-probabilities into a freshly
    allocated ``neg_logprobs`` buffer, which is saved for backward; on the
    Python side ``logits`` is only read (presumably the kernel does not write
    it either — confirm in ``softmax_xent_kernels.c``). The backward kernel
    overwrites the saved ``neg_logprobs`` in place with the gradient.
    """

    fwd_src = triton.read(
        os.path.join(os.path.dirname(__file__), "softmax_xent_kernels.c"),
        kernel_names=["softmax_fwd"],
    )
    bwd_src = triton.read(
        os.path.join(os.path.dirname(__file__), "softmax_xent_kernels.c"),
        kernel_names=["softmax_bwd"],
    )
    # Compiled kernels keyed by (device, dtype, n_vocab); the kernels are
    # compiled per device, so the device must be part of the key.
    input_config_to_kernel_fwd = {}
    input_config_to_kernel_bwd = {}

    @classmethod
    def forward(cls, ctx, logits, indices):
        """Compute the loss of ``logits`` against int64 targets ``indices``.

        Args:
            logits: scores whose last dimension has size ``n_vocab``; inputs
                with more than two dimensions are flattened to
                ``(-1, n_vocab)`` for the kernel.
            indices: int64 target indices, one per row of ``logits``
                (presumably shaped like ``logits.shape[:-1]`` — confirm).

        Returns:
            ``result`` with one value per row (flattened to ``(rows,)`` when
            ``logits`` has more than two dimensions).
        """
        n_vocab = logits.shape[-1]
        assert (indices.dtype == torch.int64
                ), "Indices are expected to be of type long."
        device, dtype = logits.device, logits.dtype
        key = (device, dtype, n_vocab)
        # compile a new kernel if needed; otherwise load from the cache
        if key not in cls.input_config_to_kernel_fwd:
            infinities = {
                torch.float16: "F16_INFINITY",
                torch.float32: "F32_INFINITY",
            }
            cls.input_config_to_kernel_fwd[key] = triton.kernel(
                cls.fwd_src,
                device=device,
                defines={
                    "TILE": make_power_of_two(n_vocab),
                    "TYPE": dtype,
                    "INFINITY": infinities[dtype],
                },
            )
        kernel_fwd = cls.input_config_to_kernel_fwd[key]
        # flatten logits; backward restores the original shape
        original_logits_shape = logits.shape
        if len(original_logits_shape) > 2:
            logits = logits.reshape((-1, n_vocab))
            indices = indices.reshape((-1, ))
        # The kernel walks rows through raw pointers, so every buffer must be
        # dense and row-major. `.contiguous()` is a no-op (no copy) when the
        # tensor already is; doing it before `empty_like` also keeps
        # `neg_logprobs` row-major instead of inheriting odd strides.
        logits = logits.contiguous()
        indices = indices.contiguous()
        result = torch.empty_like(indices, dtype=dtype, device=device)
        neg_logprobs = torch.empty_like(logits, dtype=dtype)
        n_rows = logits.shape[0]
        grid = lambda opt: (n_rows, )
        kernel_fwd(
            logits.data_ptr(),
            neg_logprobs.data_ptr(),
            indices.data_ptr(),
            result.data_ptr(),
            n_vocab,
            grid=grid,
        )
        ctx.save_for_backward(neg_logprobs, indices)
        ctx.original_logits_shape = original_logits_shape
        return result

    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """Gradient w.r.t. the logits.

        We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k], so the
        gradient is initialized as neg_logprobs, which can simply be
        exponentiated to get p[k] — most of what we need. neg_logprobs is
        modified in place to become the gradient.

        NOTE: because the saved buffer is consumed in place, running backward
        twice over the same graph (``retain_graph=True``) would read an
        already-overwritten buffer.
        """
        # load saved tensors and ensure correct types
        neg_logprobs, indices = ctx.saved_tensors
        original_logits_shape = ctx.original_logits_shape
        assert (
            dneg_logprobs.dtype == neg_logprobs.dtype
        ), f"Backward flowing derivatives of type {dneg_logprobs.dtype} != logits type {neg_logprobs.dtype}"
        n_vocab = neg_logprobs.shape[-1]
        device, dtype = neg_logprobs.device, neg_logprobs.dtype
        key = (device, dtype, n_vocab)
        # generate or load kernel
        if key not in cls.input_config_to_kernel_bwd:
            cls.input_config_to_kernel_bwd[key] = triton.kernel(
                cls.bwd_src,
                device=device,
                defines={
                    "TILE": make_power_of_two(n_vocab),
                    "TYPE": dtype,
                },
            )
        kernel_bwd = cls.input_config_to_kernel_bwd[key]
        # Incoming gradients are often broadcast views (the gradient of
        # `.sum()` is an expanded, stride-0 tensor); the kernel reads one value
        # per row through a raw pointer, so materialize a dense tensor.
        dneg_logprobs = dneg_logprobs.contiguous()
        n_rows = neg_logprobs.shape[0]
        grid = lambda opt: (n_rows, )
        # neg_logprobs is modified in place to become our gradient
        kernel_bwd(
            neg_logprobs.data_ptr(),
            indices.data_ptr(),
            dneg_logprobs.data_ptr(),
            n_vocab,
            grid=grid,
        )
        # reshape the gradient to the shape of the logits passed to forward
        if len(original_logits_shape) > 2:
            neg_logprobs = neg_logprobs.reshape(original_logits_shape)
        # indices are integer targets and are not differentiable
        return neg_logprobs, None
import triton import triton._C.libtriton as libtriton import torch import os import math src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c')) ############## # MAIN API # ############## class _matmul(torch.autograd.Function): sdd_cache = dict() dsd_cache = dict() dds_cache = dict() locks = dict() # Given an array sizes representing reduction size for each # column of a block-mode matrix multiplication, # performs load-balancing to achieve more smaller reductions # between `seg_size` elements @staticmethod def load_balance(sizes, block): # segment size # heuristics taken from OpenAI blocksparse code # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95 max_size = sizes.max() min_size = sizes[sizes != 0].min() #if max_size > min_size * 2.0: # seg_max = max(triton.cdiv(max_size, 4), min_size*2)
class _matmul(torch.autograd.Function):
    """Dense ``a @ b`` backed by the autotuned Triton kernel in ``matmul.c``.

    Only ``forward`` is defined in this class; no ``backward`` is provided here.
    """

    src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))
    # Autotuning candidates. Each entry is listed once: exact duplicates
    # would only make the autotuner benchmark the same configuration twice.
    _DEFAULT_CONFIGS = [
        triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4),
        triton.config(defines={"TM": "64", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4),
        triton.config(defines={"TM": "128", "TN": "64", "TK": "32", "SPLITK": "1"}, num_warps=4),
        triton.config(defines={"TM": "64", "TN": "64", "TK": "64", "SPLITK": "1"}, num_warps=4),
        triton.config(defines={"TM": "32", "TN": "128", "TK": "64", "SPLITK": "1"}, num_warps=4),
        triton.config(defines={"TM": "128", "TN": "32", "TK": "64", "SPLITK": "1"}, num_warps=4),
        triton.config(defines={"TM": "64", "TN": "32", "TK": "64", "SPLITK": "1"}, num_warps=2),
        triton.config(defines={"TM": "32", "TN": "64", "TK": "64", "SPLITK": "1"}, num_warps=2),
        triton.config(defines={"TM": "32", "TN": "128", "TK": "32", "SPLITK": "2"}, num_warps=4),
        triton.config(defines={"TM": "128", "TN": "32", "TK": "32", "SPLITK": "4"}, num_warps=4),
    ]
    _CONFIGS = _DEFAULT_CONFIGS

    @staticmethod
    def largest_pow2_divisor(N):
        """Return the largest of 8, 4, 2, 1 that divides ``N``."""
        if N % 8 == 0:
            return 8
        if N % 4 == 0:
            return 4
        if N % 2 == 0:
            return 2
        return 1

    # Per-device int32 lock buffers passed to the kernel (used for split-K).
    _locks = dict()
    # Compiled kernels keyed by device, dtype, operand layouts and alignments.
    _kernels = dict()

    @staticmethod
    def _call(a, b):
        """Compute ``a @ b`` for 2-D tensors ``a`` (M, K) and ``b`` (K, N).

        Raises:
            ValueError: if the inner dimensions or the dtypes of ``a`` and
                ``b`` do not match.
        """
        dtype = a.dtype
        device = a.device
        M, K = a.shape
        Kb, N = b.shape
        # The kernel trusts K; a mismatch would read past the end of `b`.
        if K != Kb:
            raise ValueError(
                f"inner dimensions do not match: a is {tuple(a.shape)}, b is {tuple(b.shape)}")
        # The kernel is compiled for a single TYPE taken from `a`.
        if b.dtype != dtype:
            raise ValueError(f"dtype mismatch: a is {dtype}, b is {b.dtype}")
        # allocate output
        c = torch.empty((M, N), dtype=dtype, device=device)
        # Degenerate sizes: nothing to launch; an empty reduction is all zeros.
        if M == 0 or N == 0:
            return c
        if K == 0:
            return c.zero_()
        # The kernel needs unit stride along one dimension of each operand.
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # kernel hash
        is_a_row = a.stride(1) == 1
        is_b_row = b.stride(1) == 1
        lda = a.stride(0) if is_a_row else a.stride(1)
        ldb = b.stride(0) if is_b_row else b.stride(1)
        ldc = c.stride(0)
        lda_pow2_div = _matmul.largest_pow2_divisor(lda)
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        # 64 is the largest TK in _DEFAULT_CONFIGS (32 and 64 are used), so
        # K % 64 == 0 means K is a multiple of every TK (presumably lets the
        # kernel skip K-boundary masking — confirm in matmul.c).
        is_tk_div_k = K % 64 == 0
        key = (
            device,
            dtype,
            is_a_row,
            is_b_row,
            lda_pow2_div,
            ldb_pow2_div,
            ldc_pow2_div,
            is_tk_div_k,
        )
        if key not in _matmul._kernels:
            defines = {
                "TYPE": dtype,
                "STRIDE_AM": "lda" if is_a_row else "1",
                "STRIDE_AK": "1" if is_a_row else "lda",
                "STRIDE_BK": "ldb" if is_b_row else "1",
                "STRIDE_BN": "1" if is_b_row else "ldb",
                "LDA_POW2_DIV": lda_pow2_div,
                "LDB_POW2_DIV": ldb_pow2_div,
                "LDC_POW2_DIV": ldc_pow2_div,
                "IS_TK_DIV_K": int(is_tk_div_k),
            }
            _matmul._kernels[key] = triton.kernel(
                _matmul.src,
                device,
                defines=defines,
                autotune_vals=_matmul._CONFIGS,
                autotune_key=["M", "N", "K"],
            )
        kernel = _matmul._kernels[key]
        # locks for split-k
        if device not in _matmul._locks:
            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
        locks = _matmul._locks[device]
        # enqueue
        alpha = 1.0
        args = [
            a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(),
            alpha,
            M,
            N,
            K,
            lda,
            ldb,
            ldc,
            locks.data_ptr(),
        ]
        # One program per (TM x TN) output tile, times SPLITK along axis 2.
        grid = lambda opt: [
            triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
            1,
            opt.SPLITK,
        ]
        kernel(*args, grid=grid)
        return c

    @staticmethod
    def forward(ctx, a, b):
        c = _matmul._call(a, b)
        return c
class _matmul(torch.autograd.Function):
    """Dense ``a @ b`` backed by the Triton kernel in ``matmul.c``.

    Tile sizes and warp counts come from the class attributes below and are
    passed to the kernel as defines. Only ``forward`` is defined in this
    class; no ``backward`` is provided here.
    """

    src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))
    TM = [128]
    TN = [128]
    TK = [32]
    TZ = 1
    num_warps = [4]

    @staticmethod
    def largest_pow2_divisor(N):
        """Return the largest of 8, 4, 2, 1 that divides ``N``."""
        if N % 8 == 0:
            return 8
        if N % 4 == 0:
            return 4
        if N % 2 == 0:
            return 2
        return 1

    # Per-device int32 lock buffers passed to the kernel (used for split-K).
    _locks = dict()
    # Compiled kernels keyed by device, dtype, operand layouts and alignments.
    _kernels = dict()

    @staticmethod
    def _call(a, b):
        """Compute ``a @ b`` for 2-D tensors ``a`` (M, K) and ``b`` (K, N).

        Raises:
            ValueError: if the inner dimensions or the dtypes of ``a`` and
                ``b`` do not match.
        """
        dtype = a.dtype
        device = a.device
        M, K = a.shape
        Kb, N = b.shape
        # The kernel trusts K; a mismatch would read past the end of `b`.
        if K != Kb:
            raise ValueError(
                f"inner dimensions do not match: a is {tuple(a.shape)}, b is {tuple(b.shape)}")
        # The kernel is compiled for a single TYPE taken from `a`.
        if b.dtype != dtype:
            raise ValueError(f"dtype mismatch: a is {dtype}, b is {b.dtype}")
        # allocate output
        c = torch.empty((M, N), dtype=dtype, device=device)
        # Degenerate sizes: nothing to launch; an empty reduction is all zeros.
        if M == 0 or N == 0:
            return c
        if K == 0:
            return c.zero_()
        # The kernel needs unit stride along one dimension of each operand.
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # kernel hash
        is_a_row = a.stride(1) == 1
        is_b_row = b.stride(1) == 1
        lda = a.stride(0) if is_a_row else a.stride(1)
        ldb = b.stride(0) if is_b_row else b.stride(1)
        ldc = c.stride(0)
        lda_pow2_div = _matmul.largest_pow2_divisor(lda)
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        # True when K is a multiple of every candidate TK (presumably lets
        # the kernel skip K-boundary masking — confirm in matmul.c).
        is_tk_div_k = all(K % tk == 0 for tk in _matmul.TK)
        key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div,
               ldc_pow2_div, is_tk_div_k)
        if key not in _matmul._kernels:
            defines = {
                'TYPE': dtype,
                'STRIDE_AM': 'lda' if is_a_row else '1',
                'STRIDE_AK': '1' if is_a_row else 'lda',
                'STRIDE_BK': 'ldb' if is_b_row else '1',
                'STRIDE_BN': '1' if is_b_row else 'ldb',
                'LDA_POW2_DIV': lda_pow2_div,
                'LDB_POW2_DIV': ldb_pow2_div,
                'LDC_POW2_DIV': ldc_pow2_div,
                'TM': _matmul.TM,
                'TN': _matmul.TN,
                'TK': _matmul.TK,
                'TZ': _matmul.TZ,
                # Defines are stringified for the C preprocessor; pass 0/1
                # rather than a bool, which would become the undefined
                # identifier `True`/`False`.
                'IS_TK_DIV_K': int(is_tk_div_k),
            }
            _matmul._kernels[key] = triton.kernel(_matmul.src,
                                                  device,
                                                  num_warps=_matmul.num_warps,
                                                  defines=defines)
        kernel = _matmul._kernels[key]
        # locks for split-k
        if device not in _matmul._locks:
            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
        locks = _matmul._locks[device]
        # enqueue
        alpha = 1.
        args = [
            a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(),
            alpha,
            M,
            N,
            K,
            lda,
            ldb,
            ldc,
            locks.data_ptr(),
        ]
        # One program per (TM x TN) output tile.
        grid = lambda opt: [
            triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
            1,
            1,
        ]
        kernel(*args, grid=grid)
        return c

    @staticmethod
    def forward(ctx, a, b):
        c = _matmul._call(a, b)
        return c
class _softmax_xent_loss_in_place(torch.autograd.Function):
    """Softmax cross-entropy loss that overwrites ``logits`` in place.

    The forward kernel turns ``logits`` into negative log-probabilities
    without copying them, and that same buffer is saved for backward, where
    the backward kernel turns it into the gradient, again in place.

    NOTE(review): ``logits`` is mutated through a raw pointer and is not
    passed to ``ctx.mark_dirty``, so autograd cannot see the mutation.
    Callers must not reuse ``logits`` (or rely on any upstream backward that
    needs it) after calling this function.
    """

    fwd_src = triton.read(
        os.path.join(os.path.dirname(__file__), "softmax_xent_kernels_in_place.c"),
        kernel_names=["softmax_fwd"],
    )
    bwd_src = triton.read(
        os.path.join(os.path.dirname(__file__), "softmax_xent_kernels_in_place.c"),
        kernel_names=["softmax_bwd"],
    )
    # Compiled kernels keyed by (device, dtype, n_vocab). TILE must equal
    # n_vocab for this approach to work, and kernels are compiled per device.
    input_config_to_kernel_fwd = {}
    input_config_to_kernel_bwd = {}

    @classmethod
    def forward(cls, ctx, logits, indices):
        """Compute the loss and overwrite ``logits`` with negative log-probs.

        Args:
            logits: contiguous scores whose last dimension ``n_vocab`` is a
                multiple of 128; any leading dimensions are treated as rows.
            indices: int64 target indices, one per row.

        Returns:
            ``result``, shaped like ``indices``.
        """
        n_vocab = logits.shape[-1]
        assert indices.dtype == torch.int64
        assert (
            n_vocab % 128 == 0), "Number of logit options must be divisible by 128."
        # The kernel writes into logits' own storage row by row through a raw
        # pointer; a copy would defeat the in-place update, so require it.
        assert logits.is_contiguous(), "Logits must be contiguous."
        device, dtype = logits.device, logits.dtype
        key = (device, dtype, n_vocab)
        if key not in cls.input_config_to_kernel_fwd:
            cls.input_config_to_kernel_fwd[key] = triton.kernel(
                cls.fwd_src,
                device=device,
                defines={
                    "TILE": n_vocab,
                    "TYPE": dtype,
                },
                num_warps=[4],
            )
        kernel_fwd = cls.input_config_to_kernel_fwd[key]
        indices = indices.contiguous()
        result = torch.zeros_like(indices, dtype=dtype, device=device)
        # One program per row of n_vocab logits. This matches the backward
        # grid and also covers inputs with more than two dimensions, where
        # logits.shape[0] would undercount the rows.
        n_rows = logits.numel() // n_vocab
        grid = lambda opt: (n_rows, )
        kernel_fwd(logits.data_ptr(), indices.data_ptr(), result.data_ptr(), grid=grid)
        # logits -> neg_logprobs via an in place modification by kernel_fwd
        ctx.save_for_backward(logits, indices)
        return result

    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """Gradient w.r.t. the logits.

        We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k], so the
        gradient is initialized as neg_logprobs, which can simply be
        exponentiated to get p[k] — most of what we need. neg_logprobs is
        modified in place to become the gradient.
        """
        neg_logprobs, indices = ctx.saved_tensors
        assert indices.dtype == torch.int64
        assert (
            dneg_logprobs.dtype == neg_logprobs.dtype
        ), f"Backward flowing derivatives of type {dneg_logprobs.dtype} != logits type {neg_logprobs.dtype}"
        n_vocab = neg_logprobs.shape[-1]
        N = neg_logprobs.numel()
        device, dtype = neg_logprobs.device, neg_logprobs.dtype
        key = (device, dtype, n_vocab)
        if key not in cls.input_config_to_kernel_bwd:
            cls.input_config_to_kernel_bwd[key] = triton.kernel(
                cls.bwd_src,
                device=device,
                defines={
                    "TILE": n_vocab,
                    "TYPE": dtype,
                },
                num_warps=[4],
            )
        kernel_bwd = cls.input_config_to_kernel_bwd[key]
        # Incoming gradients are often broadcast views (the gradient of
        # `.sum()` is an expanded, stride-0 tensor); the kernel reads one value
        # per row through a raw pointer, so materialize a dense tensor.
        dneg_logprobs = dneg_logprobs.contiguous()
        # one program per row of TILE (= n_vocab) elements
        grid = lambda opt: (triton.cdiv(N, opt.TILE), )
        # neg_logprobs is modified in place to become our gradient
        kernel_bwd(
            neg_logprobs.data_ptr(),
            indices.data_ptr(),
            dneg_logprobs.data_ptr(),
            grid=grid,
        )
        # indices are integer targets and are not differentiable
        return neg_logprobs, None
import triton import torch import os fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['forward']) fwd_kernels = dict() bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'), kernel_names=['backward']) bwd_kernels = dict() class _softmax(torch.autograd.Function): @staticmethod def next_power_of_2(n): n -= 1 n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 n += 1 return n @staticmethod def make_lut(layout, block, device): _empty = torch.tensor([], dtype=torch.int64, device=layout.device) sizes = _empty.clone() # sizes along rows for h in range(layout.shape[0]):