def backward(ctx, dy):
    # lazy compilation of kernel
    if _batchnorm.bwd_kernel is None:
        _batchnorm.bwd_kernel = triton.kernel(bwd_src, defines={'TN': 128})
    # retrieve info
    x, gamma, beta, mean, var = ctx.saved_tensors
    eps = ctx.eps
    # allocate result
    dx = triton.empty(triton.shape(x), dtype=x.dtype)
    dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
    dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
    # execute
    C, H, W, B = triton.shape(x)
    _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy, x, gamma, mean, var,
                          H * W * B, eps,
                          grid=lambda opt: [1, C])
    return dx, dgamma, dbeta, None
def forward(ctx, x, gamma, beta, eps):
    # lazy compilation of kernel
    if _batchnorm.fwd_kernel is None:
        _batchnorm.fwd_kernel = triton.kernel(fwd_src, defines={'TM': 128})
    # shapes
    shape = triton.shape(x)
    dtype = x.dtype
    # allocate outputs
    C, H, W, B = shape[0], shape[1], shape[2], shape[3]
    y = triton.empty(shape, dtype=dtype)
    mean = triton.empty([C], dtype=dtype)
    var = triton.empty([C], dtype=dtype)
    # execute kernels
    _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H * W * B, eps,
                          grid=lambda opt: [1, C])
    # save
    ctx.save_for_backward(x, gamma, beta, mean, var)
    ctx.eps = eps
    return y
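# Usage sketch (assumption, not part of the original source): the forward/backward
# pair above reads like the staticmethods of a torch.autograd.Function subclass
# named `_batchnorm`, with `fwd_src`/`bwd_src` holding the Triton-C kernel sources.
# Under that assumption, a hypothetical `batchnorm` wrapper could be exercised on
# CHWB-layout data (channels first, batch last), matching the C, H, W, B unpacking
# used by the kernels above.
import torch

def batchnorm(x, gamma, beta, eps=1e-5):
    # hypothetical convenience wrapper around the autograd Function
    return _batchnorm.apply(x, gamma, beta, eps)

x     = torch.randn(32, 8, 8, 16, device='cuda', requires_grad=True)  # C, H, W, B
gamma = torch.ones(32, device='cuda', requires_grad=True)
beta  = torch.zeros(32, device='cuda', requires_grad=True)
y     = batchnorm(x, gamma, beta)
y.sum().backward()  # dispatches to backward(ctx, dy) above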
def _call(a, b, transpose_a, transpose_b, bench):
    # extract shapes
    shape_a = triton.shape(a)
    shape_b = triton.shape(b)
    M, Ka = shape_a[0], shape_a[1]
    Kb, N = shape_b[0], shape_b[1]
    # transpose shapes
    if transpose_a:
        M, Ka = Ka, M
    if transpose_b:
        Kb, N = N, Kb
    # contiguous dimensions
    lda = M if transpose_a else Ka
    ldb = Kb if transpose_b else N
    ldc = N
    # data-type
    dtype = a.dtype
    # allocate output
    c = triton.empty([M, N], dtype=dtype)
    # compute
    grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
    # macros -- not necessary but makes kernel source-code simpler
    macros = {# handle A transposition
              'USE_A'       : '^a'         if transpose_a else 'a',
              'STRIDE_AK'   : 'lda'        if transpose_a else '1',
              'STRIDE_AM'   : '1'          if transpose_a else 'lda',
              'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
              'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
              'SHAPE_A'     : 'TK, TM'     if transpose_a else 'TM, TK',
              # handle B transposition
              'USE_B'       : '^b'         if transpose_b else 'b',
              'STRIDE_BK'   : '1'          if transpose_b else 'ldb',
              'STRIDE_BN'   : 'ldb'        if transpose_b else '1',
              'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
              'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
              'SHAPE_B'     : 'TN, TK'     if transpose_b else 'TK, TN'}
    _dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc,
                grid, bench=bench,
                AT=transpose_a, BT=transpose_b, TYPE=dtype,
                TM=[64], TN=[128], TK=[8], **macros)
    return c
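# Usage sketch (assumption, not part of the original source): `_call` appears to be
# a staticmethod of a `_dot` autograd Function whose `kernel` is compiled from a
# Triton-C matmul source. Transposition is handled purely through the preprocessor
# macros above, so one kernel source covers the NN/NT/TN/TT layouts; a hypothetical
# `dot` wrapper could then be called as follows.
import torch

def dot(a, b, transpose_a=False, transpose_b=False, bench=0):
    # hypothetical convenience wrapper around _dot._call
    return _dot._call(a, b, transpose_a, transpose_b, bench)

a = torch.randn(256, 64, device='cuda')              # M x K
b = torch.randn(64, 512, device='cuda')              # K x N
c_nn = dot(a, b)                                     # plain GEMM
c_tn = dot(a.t().contiguous(), b, transpose_a=True)  # A stored as K x M, same result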
def backward(ctx, dy):
    # retrieve info
    x, gamma, beta, mean, var = ctx.saved_tensors
    eps = ctx.eps
    # allocate result
    dx = triton.empty(triton.shape(x), dtype=x.dtype)
    dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
    dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
    # execute
    C, H, W, B = triton.shape(x)
    _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy, x, gamma, mean, var,
                          H * W * B, eps,
                          lambda opt: [1, C], TM=128)
    return dx, dgamma, dbeta, None
def forward(ctx, x, gamma, beta, eps):
    # shapes
    shape = triton.shape(x)
    dtype = x.dtype
    # allocate outputs
    C, H, W, B = shape[0], shape[1], shape[2], shape[3]
    y = triton.empty(shape, dtype=dtype)
    mean = triton.empty([C], dtype=dtype)
    var = triton.empty([C], dtype=dtype)
    # execute kernels
    _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H * W * B, eps,
                          lambda opt: [1, C], TM=128)
    # save
    ctx.save_for_backward(x, gamma, beta, mean, var)
    ctx.eps = eps
    return y
def _call(a, b,
          pad_d, pad_h, pad_w,
          stride_d, stride_h, stride_w,
          upsample_d, upsample_h, upsample_w,
          a_layout, b_layout, c_layout):
    # input shapes
    shape_a = list(triton.shape(a))
    shape_b = list(triton.shape(b))
    dim = len(shape_a) - 2
    # indices
    an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw']
    bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs']
    cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw']
    # extract shapes
    if dim == 2:
        shape_a.insert(ad, 1)
    if dim == 2:
        shape_b.insert(bd, 1)
    # output shape
    shape_c = [0] * 5
    shape_c[cn] = shape_a[an]
    shape_c[ck] = shape_b[bk]
    shape_c[cd] = (shape_a[ad] * upsample_d - shape_b[bd] + 1 + 2 * pad_d + stride_d - 1) // stride_d
    shape_c[ch] = (shape_a[ah] * upsample_h - shape_b[bh] + 1 + 2 * pad_h + stride_h - 1) // stride_h
    shape_c[cw] = (shape_a[aw] * upsample_w - shape_b[bw] + 1 + 2 * pad_w + stride_w - 1) // stride_w
    # strides
    stride_a = _conv._extract_strides(shape_a)
    stride_b = _conv._extract_strides(shape_b)
    stride_c = _conv._extract_strides(shape_c)
    # tiling parameters
    TM = [32]
    TN = [32]
    TK = 8
    # pointer deltas for a
    delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w,
                             bc, bd, bh, bw,
                             ac, ad, ah, aw,
                             stride_a, shape_b, TK)
    delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
    # delta increments for a
    inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32)
    inc_a = ((inc_a + TK) % inc_a.size) - inc_a
    inc_a = triton.fw.torch.from_numpy(inc_a).cuda()
    # allocate output
    if dim == 2:
        shape_c.pop(cd)
    c = triton.empty(shape_c, dtype=a.dtype)
    if dim == 2:
        shape_c.insert(cd, 1)
    # execute kernel
    trans_b = False
    is_wgrad = False
    is_blut = False
    macros = {'UPAR': 'stride_h' if is_wgrad else '1',
              'UPAS': '******'   if is_wgrad else '1',
              'UPAH': ''         if is_wgrad else 'stride_h',
              'UPAW': ''         if is_wgrad else 'stride_w',
              'LUT_SIZE': delta_a.shape[-1],
              'TM': TM, 'TN': TN, 'TK': TK,
              'A_TYPE': 'float', 'B_TYPE': 'float'}
    MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw]
    MATMUL_N = shape_c[ck]
    MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw]
    _conv.kernel(a, b, c,
                 # matrix multiplication shapes
                 MATMUL_M, MATMUL_N, MATMUL_K,
                 # shapes for a
                 shape_a[ah], shape_a[aw],
                 # shapes for b
                 shape_b[bh], shape_b[bw],
                 # shapes for c
                 shape_c[ch], shape_c[cw], shape_c[cn],
                 # strides for a
                 stride_a[an], stride_a[ac],
                 stride_a[ad + 0], stride_a[ad + 1], stride_a[ad + 2],
                 # strides for b
                 stride_b[bc],
                 stride_b[bd + 0], stride_b[bd + 1], stride_b[bd + 2],
                 stride_b[bk],
                 # strides for c
                 stride_c[cn], stride_c[ck],
                 stride_c[cd], stride_c[cd + 1], stride_c[cd + 2],
                 # padding
                 pad_h, pad_w,
                 # striding
                 stride_h, stride_w,
                 # upsampling
                 upsample_h, upsample_w,
                 0, 0, 0, 0, 0, 0,
                 # look-up table
                 delta_a, inc_a,
                 lambda opt: [triton.cdiv(MATMUL_M, opt.d('TM')),
                              triton.cdiv(MATMUL_N, opt.d('TN'))],
                 **macros)
    return c
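# Usage sketch (assumption, not part of the original source): `_call` appears to be a
# staticmethod of a `_conv` autograd Function that also provides `kernel`,
# `_extract_strides` and `_delta_a`. The `dim == 2` branches above insert a singleton
# depth dimension, so a 2D convolution would pass 4-D NCHW/KCRS tensors together with
# layout strings that still name the depth axis; a hypothetical `conv2d` wrapper:
import torch

def conv2d(a, b, pad=(0, 0), stride=(1, 1)):
    # depth-related parameters collapse to pad 0 / stride 1 / upsample 1 in the 2D case
    return _conv._call(a, b,
                       0, pad[0], pad[1],        # pad_d, pad_h, pad_w
                       1, stride[0], stride[1],  # stride_d, stride_h, stride_w
                       1, 1, 1,                  # upsample_d/h/w (no upsampling)
                       'ncdhw', 'kctrs', 'nkdhw')

a = torch.randn(16, 32, 24, 24, device='cuda')  # N, C, H, W
b = torch.randn(64, 32, 3, 3, device='cuda')    # K, C, R, S
c = conv2d(a, b, pad=(1, 1))                    # -> 16 x 64 x 24 x 24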