Example #1: batch normalization backward pass; the kernel is compiled lazily on first use and launched with one program instance per channel.
 def backward(ctx, dy):
     # lazy compilation of kernel
     if _batchnorm.bwd_kernel is None:
         _batchnorm.bwd_kernel = triton.kernel(bwd_src, defines={'TN': 128})
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
     # allocate result
     dx = triton.empty(triton.shape(x), dtype=x.dtype)
     dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
     dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
     # execute
     C, H, W, B = triton.shape(x)
     _batchnorm.bwd_kernel(dx,
                           dgamma,
                           dbeta,
                           dy,
                           x,
                           gamma,
                           mean,
                           var,
                           H * W * B,
                           eps,
                           grid=lambda opt: [1, C])
     return dx, dgamma, dbeta, None
Example #2: batch normalization forward pass; it computes per-channel mean and variance, normalizes the input, and saves the statistics for the backward pass.
 def forward(ctx, x, gamma, beta, eps):
     # lazy compilation of kernel
     if _batchnorm.fwd_kernel is None:
         _batchnorm.fwd_kernel = triton.kernel(fwd_src, defines={'TM': 128})
     # shapes
     shape = triton.shape(x)
     dtype = x.dtype
     # allocate outputs
     C, H, W, B = shape[0], shape[1], shape[2], shape[3]
     y = triton.empty(shape, dtype=dtype)
     mean = triton.empty([C], dtype=dtype)
     var = triton.empty([C], dtype=dtype)
     # execute kernels
     _batchnorm.fwd_kernel(y,
                           mean,
                           var,
                           x,
                           gamma,
                           beta,
                           H * W * B,
                           eps,
                           grid=lambda opt: [1, C])
     # save
     ctx.save_for_backward(x, gamma, beta, mean, var)
     ctx.eps = eps
     return y
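For context, the two batch-norm snippets above are the forward and backward methods of an autograd function. The following is a minimal wiring sketch, assuming they are staticmethods of a torch.autograd.Function subclass named _batchnorm; the class skeleton, the (C, H, W, B) tensor shapes, and the eps value below are illustrative assumptions, and the ellipses stand for the snippet bodies shown above.

import torch
import triton  # needed by the snippet bodies

class _batchnorm(torch.autograd.Function):
    fwd_kernel = None  # compiled lazily on first forward()
    bwd_kernel = None  # compiled lazily on first backward()

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        ...  # body as in Example #2

    @staticmethod
    def backward(ctx, dy):
        ...  # body as in Example #1

# Tensors laid out as (C, H, W, B), matching the unpacking in the snippets.
x     = torch.randn(32, 8, 8, 16, device='cuda', requires_grad=True)
gamma = torch.ones(32, device='cuda', requires_grad=True)
beta  = torch.zeros(32, device='cuda', requires_grad=True)

y = _batchnorm.apply(x, gamma, beta, 1e-5)  # launches fwd_kernel, one program per channel
y.sum().backward()                          # launches bwd_kernel; fills x.grad, gamma.grad, beta.grad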
Example #3: matrix multiplication wrapper; transposition of either operand is handled by swapping the shape bookkeeping and by macros that rewrite the kernel's strides, broadcasting patterns, and tile shapes.
 def _call(a, b, transpose_a, transpose_b, bench):
   # extract shapes
   shape_a = triton.shape(a)
   shape_b = triton.shape(b)
   M, Ka = shape_a[0], shape_a[1]
   Kb, N = shape_b[0], shape_b[1]
   # transpose shapes
   if transpose_a:
     M, Ka = Ka, M
   if transpose_b:
     Kb, N = N, Kb
   # contiguous dimensions
   lda = M if transpose_a else Ka
   ldb = Kb if transpose_b else N
   ldc = N
   # data-type
   dtype = a.dtype
   # allocate output
   c = triton.empty([M, N], dtype = dtype)
   # compute
   grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
   # macros -- not necessary but makes kernel source-code simpler
   macros = {# handle A transposition
             'USE_A'       : '^a'         if transpose_a else 'a',
             'STRIDE_AK'   : 'lda'        if transpose_a else '1',
             'STRIDE_AM'   : '1'          if transpose_a else 'lda',
             'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
             'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
             'SHAPE_A'     : 'TK, TM'     if transpose_a else 'TM, TK',
             # handle B transposition
             'USE_B'       : '^b'         if transpose_b else 'b',
             'STRIDE_BK'   : '1'          if transpose_b else 'ldb',
             'STRIDE_BN'   : 'ldb'        if transpose_b else '1',
             'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
             'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
             'SHAPE_B'     : 'TN, TK'     if transpose_b else 'TK, TN'}
   _dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc, 
               grid, bench=bench,           
               AT = transpose_a, BT = transpose_b, TYPE = dtype, 
               TM = [64], TN = [128], TK = [8], **macros)
   return c
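A hypothetical call-site for the wrapper above, assuming _call is a staticmethod on the same _dot class whose kernel it launches; the shapes, the transpose flags, and the interpretation of bench below are assumptions for illustration. With transpose_b=True, b is stored as (N, K) and the macros make the kernel read it transposed.

import torch

a = torch.randn(512, 64, device='cuda')   # stored as (M, K)
b = torch.randn(128, 64, device='cuda')   # stored as (N, K); read as b^T by the kernel

# bench is forwarded to the kernel launch; 0 is assumed to mean "no benchmarking".
c = _dot._call(a, b, transpose_a=False, transpose_b=True, bench=0)
# c has shape (512, 128) and holds a @ b.T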
Example #4: batch normalization backward pass; unlike Example #1, the grid callable and the TM meta-parameter are passed directly to the kernel call instead of through a separate compilation step.
 def backward(ctx, dy):
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
     # allocate result
     dx = triton.empty(triton.shape(x), dtype=x.dtype)
     dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
     dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
     # execute
     C, H, W, B = triton.shape(x)
     _batchnorm.bwd_kernel(dx,
                           dgamma,
                           dbeta,
                           dy,
                           x,
                           gamma,
                           mean,
                           var,
                           H * W * B,
                           eps,
                           lambda opt: [1, C],
                           TM=128)
     return dx, dgamma, dbeta, None
Example #5: batch normalization forward pass, the counterpart of Example #4; the grid and TM are likewise supplied at the call site.
 def forward(ctx, x, gamma, beta, eps):
     shape = triton.shape(x)
     dtype = x.dtype
     # allocate outputs
     C, H, W, B = shape[0], shape[1], shape[2], shape[3]
     y = triton.empty(shape, dtype=dtype)
     mean = triton.empty([C], dtype=dtype)
     var = triton.empty([C], dtype=dtype)
     # execute kernels
     _batchnorm.fwd_kernel(y,
                           mean,
                           var,
                           x,
                           gamma,
                           beta,
                           H * W * B,
                           eps,
                           lambda opt: [1, C],
                           TM=128)
     # save
     ctx.save_for_backward(x, gamma, beta, mean, var)
     ctx.eps = eps
     return y
Example #6: convolution wrapper; it computes the output shape from padding, stride, and upsampling, builds pointer-delta look-up tables for the input, and launches the kernel over a matmul-style grid.
 def _call(a, b, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w,
           upsample_d, upsample_h, upsample_w, a_layout, b_layout,
           c_layout):
     # input shapes
     shape_a = list(triton.shape(a))
     shape_b = list(triton.shape(b))
     dim = len(shape_a) - 2
     # indices
     an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw']
     bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs']
     cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw']
     # extract shapes
     if dim == 2:
         shape_a.insert(ad, 1)
         shape_b.insert(bd, 1)
     # output shape
     shape_c = [0] * 5
     shape_c[cn] = shape_a[an]
     shape_c[ck] = shape_b[bk]
     shape_c[cd] = (shape_a[ad] * upsample_d - shape_b[bd] + 1 + 2 * pad_d +
                    stride_d - 1) // stride_d
     shape_c[ch] = (shape_a[ah] * upsample_h - shape_b[bh] + 1 + 2 * pad_h +
                    stride_h - 1) // stride_h
     shape_c[cw] = (shape_a[aw] * upsample_w - shape_b[bw] + 1 + 2 * pad_w +
                    stride_w - 1) // stride_w
     # strides
     stride_a = _conv._extract_strides(shape_a)
     stride_b = _conv._extract_strides(shape_b)
     stride_c = _conv._extract_strides(shape_c)
     # tiling parameters
     TM = [32]
     TN = [32]
     TK = 8
     # pointer deltas for a
     delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w, bc, bd,
                              bh, bw, ac, ad, ah, aw, stride_a, shape_b, TK)
     delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
     # delta increments for a
     inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32)
     inc_a = ((inc_a + TK) % inc_a.size) - inc_a
     inc_a = triton.fw.torch.from_numpy(inc_a).cuda()
     # allocate output
     if dim == 2:
         shape_c.pop(cd)
     c = triton.empty(shape_c, dtype=a.dtype)
     if dim == 2:
         shape_c.insert(cd, 1)
     # execute kernel
     trans_b = False
     is_wgrad = False
     is_blut = False
     macros = {
         'UPAR': 'stride_h' if is_wgrad else '1',
         'UPAS': '******' if is_wgrad else '1',
         'UPAH': '' if is_wgrad else 'stride_h',
         'UPAW': '' if is_wgrad else 'stride_w',
         'LUT_SIZE': delta_a.shape[-1],
         'TM': TM,
         'TN': TN,
         'TK': TK,
         'A_TYPE': 'float',
         'B_TYPE': 'float'
     }
     MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw]
     MATMUL_N = shape_c[ck]
     MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw]
     _conv.kernel(
         a,
         b,
         c,
         # matrix multiplication shapes
         MATMUL_M,
         MATMUL_N,
         MATMUL_K,
         # shapes for a
         shape_a[ah],
         shape_a[aw],
         # shapes for b
         shape_b[bh],
         shape_b[bw],
          # shapes for c
         shape_c[ch],
         shape_c[cw],
         shape_c[cn],
         # strides for a
         stride_a[an],
         stride_a[ac],
         stride_a[ad + 0],
         stride_a[ad + 1],
         stride_a[ad + 2],
         # strides for b
         stride_b[bc],
         stride_b[bd + 0],
         stride_b[bd + 1],
         stride_b[bd + 2],
         stride_b[bk],
         # strides for c
         stride_c[cn],
         stride_c[ck],
         stride_c[cd],
         stride_c[cd + 1],
         stride_c[cd + 2],
         # padding
         pad_h,
         pad_w,
         # striding
         stride_h,
         stride_w,
         # upsampling
         upsample_h,
         upsample_w,
         0,
         0,
         0,
         0,
         0,
         0,
         # look-up table
         delta_a,
         inc_a,
         lambda opt: [
             triton.cdiv(MATMUL_M, opt.d('TM')),
             triton.cdiv(MATMUL_N, opt.d('TN'))
         ],
         **macros)
     return c
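The convolution wrapper leans on a helper, _conv._extract_strides, that is not shown in the snippet. A plausible sketch of it, assuming a dense row-major (C-contiguous) layout so that each stride is the product of all trailing dimensions:

def _extract_strides(shape):
    # stride[i] = number of elements between consecutive indices along axis i
    rank = len(shape)
    strides = [1] * rank
    for i in range(rank - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# e.g. _extract_strides([16, 32, 1, 28, 28]) -> [25088, 784, 784, 28, 1]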