Example #1
 def forward(ctx, x, gamma, beta, eps):
     # lazy compilation of kernel
     if _batchnorm.fwd_kernel is None:
         _batchnorm.fwd_kernel = triton.kernel(fwd_src, defines={'TM': 128})
     # shapes
     shape = triton.shape(x)
     dtype = x.dtype
     # allocate outputs
     C, H, W, B = shape[0], shape[1], shape[2], shape[3]
     y = triton.empty(shape, dtype=dtype)
     mean = triton.empty([C], dtype=dtype)
     var = triton.empty([C], dtype=dtype)
     # execute kernels
     _batchnorm.fwd_kernel(y,
                           mean,
                           var,
                           x,
                           gamma,
                           beta,
                           H * W * B,
                           eps,
                           grid=lambda opt: [1, C])
     # save
     ctx.save_for_backward(x, gamma, beta, mean, var)
     ctx.eps = eps
     return y
Example #2
 def backward(ctx, dy):
     # lazy compilation of kernel
     if _batchnorm.bwd_kernel is None:
         _batchnorm.bwd_kernel = triton.kernel(bwd_src, defines={'TN': 128})
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
     # allocate result
     dx = triton.empty(triton.shape(x), dtype=x.dtype)
     dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
     dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
     # execute
     C, H, W, B = triton.shape(x)
     _batchnorm.bwd_kernel(dx,
                           dgamma,
                           dbeta,
                           dy,
                           x,
                           gamma,
                           mean,
                           var,
                           H * W * B,
                           eps,
                           grid=lambda opt: [1, C])
     return dx, dgamma, dbeta, None
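Examples #1 and #2 are the forward and backward halves of a custom batchnorm op. Below is a minimal sketch of how they would typically be tied together, assuming the legacy triton.kernel API used above; the class layout and the exported name are illustrative, not taken from the original source.

 import torch

 class _batchnorm(torch.autograd.Function):
     # kernels are compiled lazily on first use (see Examples #1 and #2)
     fwd_kernel = None
     bwd_kernel = None

     @staticmethod
     def forward(ctx, x, gamma, beta, eps):
         ...  # body from Example #1

     @staticmethod
     def backward(ctx, dy):
         ...  # body from Example #2

 batchnorm = _batchnorm.apply  # e.g. y = batchnorm(x, gamma, beta, 1e-5)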
Example #3
 def forward(ctx, einsum, a, b, shape_c, **kwargs):
     bench = kwargs['bench'] if 'bench' in kwargs else False
     arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
     # allocate output
     dtype = a.dtype
     c = triton.empty(shape_c, dtype=dtype)
     # compile einsum instance
     cache = _einsum.instance_cache
     key = (einsum, dtype, 
            a.stride(), b.stride(), c.stride(), 
            a.shape, b.shape, c.shape)
     if key not in cache:
         cache[key] = _einsum.instance(einsum, dtype, 
                                       a.stride(), b.stride(), c.stride(),
                                       a.shape, b.shape, c.shape, arrays)
     instance = cache[key]
     instance.run(a, b, c, bench)
     # save information in context
     ctx.flops = instance.flops
     ctx.sym_a = instance.sym_a
     ctx.sym_b = instance.sym_b
     ctx.sym_c = instance.sym_c
     ctx.matmul_B = instance.matmul_B
     ctx.matmul_M = instance.matmul_M
     ctx.matmul_N = instance.matmul_N
     ctx.matmul_K = instance.matmul_K
     ctx.bench = bench
     ctx.save_for_backward(a, b)
     return c
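The forward in Example #3 compiles one einsum instance per (expression, dtype, strides, shapes) configuration and memoizes it in _einsum.instance_cache, so recompilation only happens when the problem geometry changes. A hypothetical call, assuming the legacy API is available and the class is exposed through torch.autograd.Function.apply (the einsum string and shapes are illustrative):

 import torch

 a = torch.randn(8, 64, 32, device='cuda')
 b = torch.randn(8, 32, 128, device='cuda')
 # batched matrix multiply written as an einsum; the output shape is passed explicitly
 c = _einsum.apply('bmk,bkn->bmn', a, b, [8, 64, 128])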
Example #4
 def forward(ctx, x, gamma, beta, eps):
     shape = triton.shape(x)
     dtype = x.dtype
     # allocate outputs
     C, H, W, B = shape[0], shape[1], shape[2], shape[3]
     y = triton.empty(shape, dtype=dtype)
     mean = triton.empty([C], dtype=dtype)
     var = triton.empty([C], dtype=dtype)
     # execute kernels
     _batchnorm.fwd_kernel(y,
                           mean,
                           var,
                           x,
                           gamma,
                           beta,
                           H * W * B,
                           eps,
                           lambda opt: [1, C],
                           TM=128)
     # save
     ctx.save_for_backward(x, gamma, beta, mean, var)
     ctx.eps = eps
     return y
Example #5
 def backward(ctx, dy):
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
     # allocate result
     dx = triton.empty(triton.shape(x), dtype=x.dtype)
     dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
     dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
     # execute
     C, H, W, B = triton.shape(x)
     _batchnorm.bwd_kernel(dx,
                           dgamma,
                           dbeta,
                           dy,
                           x,
                           gamma,
                           mean,
                           var,
                           H * W * B,
                           eps,
                           lambda opt: [1, C],
                           TM=128)
     return dx, dgamma, dbeta, None
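Examples #4 and #5 implement the same batchnorm forward and backward as Examples #1 and #2, but instead of fixing the tile size through defines when the kernel object is constructed, they pass the launch grid and the meta-parameter TM=128 directly at call time.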
Example #6
 def _call(a, b):
     # create kernel if necessary
     dtype = a.dtype
     if dtype not in _dot.kernel:
         defines = {
             'TYPE': dtype,
             'STRIDE_AM': '1',
             'STRIDE_AK': 'lda',
             'STRIDE_BN': '1',
             'STRIDE_BK': 'ldb',
             'TM': [64, 128],
             'TN': [64, 128],
             'TK': [8, 16],
             'TZ': [1]
         }
         _dot.kernel[dtype] = triton.kernel(_dot.src,
                                            num_warps=[4],
                                            defines=defines)
     kernel = _dot.kernel[dtype]
     # allocate output
     M, K = a.shape
     K, N = b.shape
     c = triton.empty([M, N], dtype=dtype)
     # enqueue
     grid = lambda opt: [
         triton.cdiv(M, opt.d('TM')),
         triton.cdiv(N, opt.d('TN'))
     ]
     time = kernel(a,
                   b,
                   c,
                   1.,
                   M,
                   N,
                   K,
                   a.stride(0),
                   b.stride(0),
                   c.stride(0),
                   grid=grid,
                   bench=100)
     print(2 * M * N * K / (time * 1e-6) * 1e-9)  # report effective throughput
     return c
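A hypothetical use of the _call above, assuming it is exposed as a staticmethod of the _dot class and the legacy API is available on a CUDA device (the shapes are illustrative). The list-valued TM/TN/TK defines give the autotuner a space of tile sizes to choose from:

 import torch

 a = torch.randn(1024, 512, device='cuda', dtype=torch.float32)
 b = torch.randn(512, 2048, device='cuda', dtype=torch.float32)
 c = _dot._call(a, b)  # benchmarks the launch and prints the measured throughput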
Example #7
 def _call(a, b, transpose_a, transpose_b, bench):
   # extract shapes
   shape_a = triton.shape(a)
   shape_b = triton.shape(b)
   M, Ka = shape_a[0], shape_a[1]
   Kb, N = shape_b[0], shape_b[1]
   # transpose shapes
   if transpose_a:
     M, Ka = Ka, M
   if transpose_b:
     Kb, N = N, Kb
   # contiguous dimensions
   lda = M if transpose_a else Ka
   ldb = Kb if transpose_b else N
   ldc = N
   # data-type
   dtype = a.dtype
   # allocate output
   c = triton.empty([M, N], dtype = dtype)
   # compute
   grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
   # macros -- not strictly necessary, but they make the kernel source code simpler
   macros = {# handle A transposition
             'USE_A'       : '^a'         if transpose_a else 'a',
             'STRIDE_AK'   : 'lda'        if transpose_a else '1',
             'STRIDE_AM'   : '1'          if transpose_a else 'lda',
             'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
             'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
             'SHAPE_A'     : 'TK, TM'     if transpose_a else 'TM, TK',
             # handle B transposition
             'USE_B'       : '^b'         if transpose_b else 'b',
             'STRIDE_BK'   : '1'          if transpose_b else 'ldb',
             'STRIDE_BN'   : 'ldb'        if transpose_b else '1',
             'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
             'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
             'SHAPE_B'     : 'TN, TK'     if transpose_b else 'TK, TN'}
   _dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc, 
               grid, bench=bench,           
               AT = transpose_a, BT = transpose_b, TYPE = dtype, 
               TM = [64], TN = [128], TK = [8], **macros)
   return c
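Example #7 wraps a similar GEMM to Example #6, but supports every combination of transpose_a and transpose_b from a single kernel source: the strides, broadcast patterns and tile shapes of each operand are chosen through preprocessor macros instead of maintaining four separate kernels.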
Example #8
 def forward(ctx, a, b, pad, stride, time):
     # create kernel if necessary
     dtype = a.dtype
     # shapes
     Z, CI, H, W = a.shape
     _, R, S, CO = b.shape
     P = (H + 2 * pad[0] - R) // stride[0] + 1
     Q = (W + 2 * pad[1] - S) // stride[1] + 1
     # compile kernel
     if dtype not in _conv.kernel:
         TK = 8
         defines = {
             'TYPE': dtype,
             'TM': [16, 32, 64, 128],
             'TN': [16, 32, 64, 128],
             'TK': [TK],
             'TZ': [1],
             'HH': H,
             'WW': W,
             'PP': P,
             'QQ': Q,
             'SS': S,
             'RR': R,
         }
         idx = torch.arange(CI * R * S)
         ci, r, s = _conv.unpack(idx, CI, R, S)
         nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
         delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (
             ns - s) * a.stride(3)
         delta = delta.type(torch.int32).cuda()
         _conv.kernel[dtype] = (delta,
                                triton.kernel(_conv.src,
                                              num_warps=[2, 4],
                                              defines=defines))
     delta, kernel = _conv.kernel[dtype]
     # allocate output
     c = triton.empty([Z, CO, P, Q], dtype=dtype)
     # enqueue
     grid = lambda opt: [
         triton.cdiv(Z * P * Q, opt.d('TM')),
         triton.cdiv(CO, opt.d('TN'))
     ]
     time[0] = kernel(a,
                      b,
                      c,
                      1.,
                      Z * P * Q,
                      CO,
                      CI * R * S,
                      pad[0],
                      pad[1],
                      stride[0],
                      stride[1],
                      delta,
                      a.stride(0),
                      a.stride(1),
                      a.stride(2),
                      a.stride(3),
                      b.stride(0),
                      b.stride(1),
                      b.stride(2),
                      b.stride(3),
                      c.stride(0),
                      c.stride(1),
                      c.stride(2),
                      c.stride(3),
                      grid=grid,
                      bench=100)
     return c
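The output spatial size in Example #8 follows the usual convolution formula. A quick check of the arithmetic with illustrative numbers (assumed, not from the original):

 # with H = W = 32, R = S = 3, pad = (1, 1), stride = (1, 1)
 H, W, R, S = 32, 32, 3, 3
 pad, stride = (1, 1), (1, 1)
 P = (H + 2 * pad[0] - R) // stride[0] + 1  # -> 32, spatial size preserved
 Q = (W + 2 * pad[1] - S) // stride[1] + 1  # -> 32
 assert (P, Q) == (32, 32)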
Example #9
 def _call(a, b, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w,
           upsample_d, upsample_h, upsample_w, a_layout, b_layout,
           c_layout):
     # input shapes
     shape_a = list(triton.shape(a))
     shape_b = list(triton.shape(b))
     dim = len(shape_a) - 2
     # indices
     an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw']
     bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs']
     cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw']
     # extract shapes
     if dim == 2:
         shape_a.insert(ad, 1)
     if dim == 2:
         shape_b.insert(bd, 1)
     # output shape
     shape_c = [0] * 5
     shape_c[cn] = shape_a[an]
     shape_c[ck] = shape_b[bk]
     shape_c[cd] = (shape_a[ad] * upsample_d - shape_b[bd] + 1 + 2 * pad_d +
                    stride_d - 1) // stride_d
     shape_c[ch] = (shape_a[ah] * upsample_h - shape_b[bh] + 1 + 2 * pad_h +
                    stride_h - 1) // stride_h
     shape_c[cw] = (shape_a[aw] * upsample_w - shape_b[bw] + 1 + 2 * pad_w +
                    stride_w - 1) // stride_w
     # strides
     stride_a = _conv._extract_strides(shape_a)
     stride_b = _conv._extract_strides(shape_b)
     stride_c = _conv._extract_strides(shape_c)
     # tiling parameters
     TM = [32]
     TN = [32]
     TK = 8
     # pointer deltas for a
     delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w, bc, bd,
                              bh, bw, ac, ad, ah, aw, stride_a, shape_b, TK)
     delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
     # delta increments for a
     inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32)
     inc_a = ((inc_a + TK) % inc_a.size) - inc_a
     inc_a = triton.fw.torch.from_numpy(inc_a).cuda()
     # allocate output
     if dim == 2:
         shape_c.pop(cd)
     c = triton.empty(shape_c, dtype=a.dtype)
     if dim == 2:
         shape_c.insert(cd, 1)
     # execute kernel
     trans_b = False
     is_wgrad = False
     is_blut = False
     macros = {
         'UPAR': 'stride_h' if is_wgrad else '1',
         'UPAS': '******' if is_wgrad else '1',
         'UPAH': '' if is_wgrad else 'stride_h',
         'UPAW': '' if is_wgrad else 'stride_w',
         'LUT_SIZE': delta_a.shape[-1],
         'TM': TM,
         'TN': TN,
         'TK': TK,
         'A_TYPE': 'float',
         'B_TYPE': 'float'
     }
     MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw]
     MATMUL_N = shape_c[ck]
     MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw]
     _conv.kernel(
         a,
         b,
         c,
         # matrix multiplication shapes
         MATMUL_M,
         MATMUL_N,
         MATMUL_K,
         # shapes for a
         shape_a[ah],
         shape_a[aw],
         # shapes for b
         shape_b[bh],
         shape_b[bw],
         # shapes for c
         shape_c[ch],
         shape_c[cw],
         shape_c[cn],
         # strides for a
         stride_a[an],
         stride_a[ac],
         stride_a[ad + 0],
         stride_a[ad + 1],
         stride_a[ad + 2],
         # strides for b
         stride_b[bc],
         stride_b[bd + 0],
         stride_b[bd + 1],
         stride_b[bd + 2],
         stride_b[bk],
         # strides for c
         stride_c[cn],
         stride_c[ck],
         stride_c[cd],
         stride_c[cd + 1],
         stride_c[cd + 2],
         # padding
         pad_h,
         pad_w,
         # striding
         stride_h,
         stride_w,
         # upsampling
         upsample_h,
         upsample_w,
         0,
         0,
         0,
         0,
         0,
         0,
         # look-up table
         delta_a,
         inc_a,
         lambda opt: [
             triton.cdiv(MATMUL_M, opt.d('TM')),
             triton.cdiv(MATMUL_N, opt.d('TN'))
         ],
         **macros)
     return c
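The shape_c computation in Example #9 is a ceiling division by the stride, applied after upsampling and padding. A quick check with illustrative numbers (assumed):

 H, R = 32, 3
 pad_h, stride_h, upsample_h = 1, 2, 1
 out_h = (H * upsample_h - R + 1 + 2 * pad_h + stride_h - 1) // stride_h
 assert out_h == 16  # ceil((32 - 3 + 1 + 2) / 2)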