def forward(ctx, x, gamma, beta, eps):
    # lazy compilation of kernel
    if _batchnorm.fwd_kernel is None:
        _batchnorm.fwd_kernel = triton.kernel(fwd_src, defines={'TM': 128})
    # shapes
    shape = triton.shape(x)
    dtype = x.dtype
    # allocate outputs
    C, H, W, B = shape[0], shape[1], shape[2], shape[3]
    y = triton.empty(shape, dtype=dtype)
    mean = triton.empty([C], dtype=dtype)
    var = triton.empty([C], dtype=dtype)
    # execute kernels
    _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H * W * B, eps,
                          grid=lambda opt: [1, C])
    # save
    ctx.save_for_backward(x, gamma, beta, mean, var)
    ctx.eps = eps
    return y
def backward(ctx, dy):
    # lazy compilation of kernel
    if _batchnorm.bwd_kernel is None:
        _batchnorm.bwd_kernel = triton.kernel(bwd_src, defines={'TN': 128})
    # retrieve info
    x, gamma, beta, mean, var = ctx.saved_tensors
    eps = ctx.eps
    # allocate result
    dx = triton.empty(triton.shape(x), dtype=x.dtype)
    dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
    dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
    # execute
    C, H, W, B = triton.shape(x)
    _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy, x, gamma, mean, var, H * W * B, eps,
                          grid=lambda opt: [1, C])
    return dx, dgamma, dbeta, None
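# Usage sketch (an assumption, not shown in the source above): if _batchnorm subclasses
# torch.autograd.Function and fwd_src / bwd_src hold the corresponding Triton kernel
# sources, the forward/backward pair could be exercised as follows. The (C, H, W, B)
# layout and the sizes below are illustrative only.
import torch

x = torch.randn(32, 16, 16, 8, device='cuda', requires_grad=True)   # C, H, W, B
gamma = torch.ones(32, device='cuda', requires_grad=True)
beta = torch.zeros(32, device='cuda', requires_grad=True)
y = _batchnorm.apply(x, gamma, beta, 1e-5)   # invokes forward()
y.sum().backward()                           # invokes backward()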
def forward(ctx, einsum, a, b, shape_c, **kwargs):
    bench = kwargs['bench'] if 'bench' in kwargs else False
    arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
    # allocate output
    dtype = a.dtype
    c = triton.empty(shape_c, dtype=dtype)
    # compile einsum instance
    cache = _einsum.instance_cache
    key = (einsum, dtype, a.stride(), b.stride(), c.stride(), a.shape, b.shape, c.shape)
    if key not in cache:
        cache[key] = _einsum.instance(einsum, dtype, a.stride(), b.stride(), c.stride(),
                                      a.shape, b.shape, c.shape, arrays)
    instance = cache[key]
    instance.run(a, b, c, bench)
    # save information in context
    ctx.flops = instance.flops
    ctx.sym_a = instance.sym_a
    ctx.sym_b = instance.sym_b
    ctx.sym_c = instance.sym_c
    ctx.matmul_B = instance.matmul_B
    ctx.matmul_M = instance.matmul_M
    ctx.matmul_N = instance.matmul_N
    ctx.matmul_K = instance.matmul_K
    ctx.bench = bench
    ctx.save_for_backward(a, b)
    return c
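# Usage sketch (an assumption: _einsum is a torch.autograd.Function and the einsum
# string follows the usual "mk,kn->mn" convention). The instance cache is keyed on the
# einsum string together with dtype, shapes and strides, so a kernel instance is only
# compiled the first time a given configuration is seen.
import torch

a = torch.randn(128, 256, device='cuda', requires_grad=True)
b = torch.randn(256, 512, device='cuda', requires_grad=True)
c = _einsum.apply('mk,kn->mn', a, b, [128, 512])   # shape_c passed explicitly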
def forward(ctx, x, gamma, beta, eps):
    shape = triton.shape(x)
    dtype = x.dtype
    # allocate outputs
    C, H, W, B = shape[0], shape[1], shape[2], shape[3]
    y = triton.empty(shape, dtype=dtype)
    mean = triton.empty([C], dtype=dtype)
    var = triton.empty([C], dtype=dtype)
    # execute kernels
    _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H * W * B, eps,
                          lambda opt: [1, C], TM=128)
    # save
    ctx.save_for_backward(x, gamma, beta, mean, var)
    ctx.eps = eps
    return y
def backward(ctx, dy):
    # retrieve info
    x, gamma, beta, mean, var = ctx.saved_tensors
    eps = ctx.eps
    # allocate result
    dx = triton.empty(triton.shape(x), dtype=x.dtype)
    dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
    dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
    # execute
    C, H, W, B = triton.shape(x)
    _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy, x, gamma, mean, var, H * W * B, eps,
                          lambda opt: [1, C], TM=128)
    return dx, dgamma, dbeta, None
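# Reference sketch (assumes only the (C, H, W, B) layout used above): the [1, C] grid
# launches one program per channel, each reducing over the H*W*B elements of its
# channel. The same computation in plain PyTorch:
import torch

def batchnorm_reference(x, gamma, beta, eps):
    x2d = x.reshape(x.shape[0], -1)                       # (C, H*W*B)
    mean = x2d.mean(dim=1, keepdim=True)
    var = x2d.var(dim=1, unbiased=False, keepdim=True)
    y = (x2d - mean) / torch.sqrt(var + eps)
    return (gamma[:, None] * y + beta[:, None]).reshape(x.shape)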
def _call(a, b):
    # create kernel if necessary
    dtype = a.dtype
    if dtype not in _dot.kernel:
        defines = {
            'TYPE': dtype,
            'STRIDE_AM': '1',
            'STRIDE_AK': 'lda',
            'STRIDE_BN': '1',
            'STRIDE_BK': 'ldb',
            'TM': [64, 128],
            'TN': [64, 128],
            'TK': [8, 16],
            'TZ': [1]
        }
        _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
    kernel = _dot.kernel[dtype]
    # allocate output
    M, K = a.shape
    K, N = b.shape
    c = triton.empty([M, N], dtype=dtype)
    # enqueue
    grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
    time = kernel(a, b, c, 1., M, N, K,
                  a.stride(0), b.stride(0), c.stride(0),
                  grid=grid, bench=100)
    print(2 * M * N * K / (time * 1e-6) * 1e-9)
    return c
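# Usage sketch (an assumption: triton.empty returns a CUDA torch tensor). This variant
# compiles its kernel lazily per dtype, benchmarks the launch (bench=100), and prints a
# throughput figure derived from the 2*M*N*K flops of the product.
import torch

a = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, 1024, device='cuda')
c = _dot._call(a, b)   # prints the measured throughput as a side effect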
def _call(a, b, transpose_a, transpose_b, bench):
    # extract shapes
    shape_a = triton.shape(a)
    shape_b = triton.shape(b)
    M, Ka = shape_a[0], shape_a[1]
    Kb, N = shape_b[0], shape_b[1]
    # transpose shapes
    if transpose_a:
        M, Ka = Ka, M
    if transpose_b:
        Kb, N = N, Kb
    # contiguous dimensions
    lda = M if transpose_a else Ka
    ldb = Kb if transpose_b else N
    ldc = N
    # data-type
    dtype = a.dtype
    # allocate output
    c = triton.empty([M, N], dtype=dtype)
    # compute
    grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
    # macros -- not necessary but makes kernel source-code simpler
    macros = {
        # handle A transposition
        'USE_A':        '^a'         if transpose_a else 'a',
        'STRIDE_AK':    'lda'        if transpose_a else '1',
        'STRIDE_AM':    '1'          if transpose_a else 'lda',
        'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
        'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
        'SHAPE_A':      'TK, TM'     if transpose_a else 'TM, TK',
        # handle B transposition
        'USE_B':        '^b'         if transpose_b else 'b',
        'STRIDE_BK':    '1'          if transpose_b else 'ldb',
        'STRIDE_BN':    'ldb'        if transpose_b else '1',
        'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
        'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
        'SHAPE_B':      'TN, TK'     if transpose_b else 'TK, TN'
    }
    _dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc, grid,
                bench=bench, AT=transpose_a, BT=transpose_b, TYPE=dtype,
                TM=[64], TN=[128], TK=[8], **macros)
    return c
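# Reference sketch (an assumption: triton.empty returns a CUDA torch tensor): the macros
# above configure the kernel to compute op(a) @ op(b), where op transposes its argument
# when the corresponding flag is set. A quick check against PyTorch for the
# (transpose_a=True, transpose_b=False) case:
import torch

a = torch.randn(512, 64, device='cuda')
b = torch.randn(512, 128, device='cuda')
c = _dot._call(a, b, transpose_a=True, transpose_b=False, bench=0)
assert torch.allclose(c, a.t() @ b, atol=1e-2, rtol=1e-2)   # (64, 512) @ (512, 128)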
def forward(ctx, a, b, pad, stride, time):
    # create kernel if necessary
    dtype = a.dtype
    # shapes
    Z, CI, H, W = a.shape
    _, R, S, CO = b.shape
    P = (H + 2 * pad[0] - R) // stride[0] + 1
    Q = (W + 2 * pad[1] - S) // stride[1] + 1
    # compile kernel
    if dtype not in _conv.kernel:
        TK = 8
        defines = {
            'TYPE': dtype,
            'TM': [16, 32, 64, 128],
            'TN': [16, 32, 64, 128],
            'TK': [TK],
            'TZ': [1],
            'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
        }
        idx = torch.arange(CI * R * S)
        ci, r, s = _conv.unpack(idx, CI, R, S)
        nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
        delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (ns - s) * a.stride(3)
        delta = delta.type(torch.int32).cuda()
        _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
    delta, kernel = _conv.kernel[dtype]
    # allocate output
    c = triton.empty([Z, CO, P, Q], dtype=dtype)
    # enqueue
    grid = lambda opt: [triton.cdiv(Z * P * Q, opt.d('TM')), triton.cdiv(CO, opt.d('TN'))]
    time[0] = kernel(a, b, c, 1., Z * P * Q, CO, CI * R * S,
                     pad[0], pad[1], stride[0], stride[1], delta,
                     a.stride(0), a.stride(1), a.stride(2), a.stride(3),
                     b.stride(0), b.stride(1), b.stride(2), b.stride(3),
                     c.stride(0), c.stride(1), c.stride(2), c.stride(3),
                     grid=grid, bench=100)
    return c
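# Plausible implementation of the _conv.unpack helper used above (hypothetical; the real
# helper lives elsewhere in the class). It splits a flat index over (CI, R, S) into its
# three coordinates, here with s varying fastest; the actual ordering must match how the
# kernel source walks the reduction dimension when following the delta table.
def unpack(idx, CI, R, S):
    s = idx % S
    cr = idx // S
    r = cr % R
    ci = cr // R
    return ci, r, s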
def _call(a, b, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w,
          upsample_d, upsample_h, upsample_w, a_layout, b_layout, c_layout):
    # input shapes
    shape_a = list(triton.shape(a))
    shape_b = list(triton.shape(b))
    dim = len(shape_a) - 2
    # indices
    an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw']
    bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs']
    cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw']
    # extract shapes
    if dim == 2:
        shape_a.insert(ad, 1)
    if dim == 2:
        shape_b.insert(bd, 1)
    # output shape
    shape_c = [0] * 5
    shape_c[cn] = shape_a[an]
    shape_c[ck] = shape_b[bk]
    shape_c[cd] = (shape_a[ad] * upsample_d - shape_b[bd] + 1 + 2 * pad_d + stride_d - 1) // stride_d
    shape_c[ch] = (shape_a[ah] * upsample_h - shape_b[bh] + 1 + 2 * pad_h + stride_h - 1) // stride_h
    shape_c[cw] = (shape_a[aw] * upsample_w - shape_b[bw] + 1 + 2 * pad_w + stride_w - 1) // stride_w
    # strides
    stride_a = _conv._extract_strides(shape_a)
    stride_b = _conv._extract_strides(shape_b)
    stride_c = _conv._extract_strides(shape_c)
    # tiling parameters
    TM = [32]
    TN = [32]
    TK = 8
    # pointer deltas for a
    delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w,
                             bc, bd, bh, bw, ac, ad, ah, aw,
                             stride_a, shape_b, TK)
    delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
    # delta increments for a
    inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32)
    inc_a = ((inc_a + TK) % inc_a.size) - inc_a
    inc_a = triton.fw.torch.from_numpy(inc_a).cuda()
    # allocate output
    if dim == 2:
        shape_c.pop(cd)
    c = triton.empty(shape_c, dtype=a.dtype)
    if dim == 2:
        shape_c.insert(cd, 1)
    # execute kernel
    trans_b = False
    is_wgrad = False
    is_blut = False
    macros = {
        'UPAR': 'stride_h' if is_wgrad else '1',
        'UPAS': '******' if is_wgrad else '1',
        'UPAH': '' if is_wgrad else 'stride_h',
        'UPAW': '' if is_wgrad else 'stride_w',
        'LUT_SIZE': delta_a.shape[-1],
        'TM': TM, 'TN': TN, 'TK': TK,
        'A_TYPE': 'float', 'B_TYPE': 'float'
    }
    MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw]
    MATMUL_N = shape_c[ck]
    MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw]
    _conv.kernel(a, b, c,
                 # matrix multiplication shapes
                 MATMUL_M, MATMUL_N, MATMUL_K,
                 # shapes for a
                 shape_a[ah], shape_a[aw],
                 # shapes for b
                 shape_b[bh], shape_b[bw],
                 # shapes for c
                 shape_c[ch], shape_c[cw], shape_c[cn],
                 # strides for a
                 stride_a[an], stride_a[ac],
                 stride_a[ad + 0], stride_a[ad + 1], stride_a[ad + 2],
                 # strides for b
                 stride_b[bc],
                 stride_b[bd + 0], stride_b[bd + 1], stride_b[bd + 2],
                 stride_b[bk],
                 # strides for c
                 stride_c[cn], stride_c[ck],
                 stride_c[cd], stride_c[cd + 1], stride_c[cd + 2],
                 # padding
                 pad_h, pad_w,
                 # striding
                 stride_h, stride_w,
                 # upsampling
                 upsample_h, upsample_w,
                 0, 0, 0, 0, 0, 0,
                 # look-up table
                 delta_a, inc_a,
                 lambda opt: [triton.cdiv(MATMUL_M, opt.d('TM')),
                              triton.cdiv(MATMUL_N, opt.d('TN'))],
                 **macros)
    return c
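# Output-extent sketch: with upsample == 1 the shape_c formula above is a ceiling
# division and reduces to the familiar (in + 2*pad - kernel) // stride + 1, since
# (x + stride - 1) // stride == ceil(x / stride).
def out_extent(in_size, kernel, pad, stride, upsample=1):
    return (in_size * upsample - kernel + 1 + 2 * pad + stride - 1) // stride

assert out_extent(32, 3, 1, 1) == 32                    # 'same' padding, unit stride
assert out_extent(32, 3, 0, 2) == (32 - 3) // 2 + 1     # strided, no padding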