Example #1
    def forward(ctx, x):
        M, N = x.shape
        ldx = N
        dtype = x.dtype
        # allocate the output with the same dtype and device as the input
        y = torch.empty((M, N), dtype=dtype, device=x.device)
        # tile sizes are auto-tuned over the listed candidate values
        defines = {
            'TYPE': dtype,
            'TM': [32, 64, 128],
            'TN': [32, 64, 128],
        }
        grid = lambda opt: [
            triton.cdiv(M, opt.d('TM')),
            triton.cdiv(N, opt.d('TN'))
        ]
        # lazily compile the kernel on first use
        if _copy.kernel is None:
            _copy.kernel = triton.kernel(_copy.src,
                                         defines=defines,
                                         num_warps=[4])
        _copy.kernel(x, y, M, N, ldx, grid=grid)
        return y
Example #2
def make_kernel(device, dtype):
    key = (device, dtype)
    cache = make_kernel.cache
    if key not in cache:
        defines = {'TYPE': dtype}
        cache[key] = triton.kernel(src,
                                   device=device,
                                   defines=defines,
                                   autotune_vals=autotune_configs,
                                   autotune_key=autotune_key)
    return cache[key]
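A possible call site for this memoized factory, as a minimal sketch: make_kernel stores compiled kernels on its own cache attribute, which must be attached once at module scope, and the src, autotune_configs, and autotune_key globals it reads are assumed to be defined alongside it.

import torch

# hypothetical setup: attach the cache attribute that make_kernel expects
make_kernel.cache = {}

# each distinct (device, dtype) pair triggers exactly one compilation;
# subsequent calls are served from the cache
k = make_kernel(torch.device('cuda'), torch.float16)
assert make_kernel(torch.device('cuda'), torch.float16) is k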
Example #3
 def forward(ctx, x, gamma, beta, eps):
     # lazy compilation of kernel
     if _batchnorm.fwd_kernel is None:
         _batchnorm.fwd_kernel = triton.kernel(fwd_src, defines={'TM': 128})
     # shapes
     shape = triton.shape(x)
     dtype = x.dtype
     # allocate outputs
     C, H, W, B = shape[0], shape[1], shape[2], shape[3]
     y = triton.empty(shape, dtype=dtype)
     mean = triton.empty([C], dtype=dtype)
     var = triton.empty([C], dtype=dtype)
     # execute kernels
     _batchnorm.fwd_kernel(y,
                           mean,
                           var,
                           x,
                           gamma,
                           beta,
                           H * W * B,
                           eps,
                           grid=lambda opt: [1, C])
     # save
     ctx.save_for_backward(x, gamma, beta, mean, var)
     ctx.eps = eps
     return y
Example #4
class _add(torch.autograd.Function):
    src = """
__global__ void add(float* z, float* x, float* y, int N) {

    int pid = get_program_id(0);

    int offset[TILE] = pid * TILE + 0 ... TILE;
    float* pz[TILE]  = z + offset;
    float* px[TILE]  = x + offset;
    float* py[TILE]  = y + offset;

    bool check[TILE] = offset < N;

    *?(check)pz = *?(check)px + *?(check)py;
}
    """

    kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])

    @staticmethod
    def forward(ctx, x, y):
        # x is expected to already live on the GPU, so empty_like suffices
        z = torch.empty_like(x)

        N = x.numel()
        grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )

        _add.kernel(z, x, y, N, grid=grid)

        return z
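Since _add subclasses torch.autograd.Function, it is invoked through apply. A minimal usage sketch, assuming CUDA tensors and the legacy triton.kernel API:

import torch

x = torch.rand(2048, device='cuda')
y = torch.rand(2048, device='cuda')
z = _add.apply(x, y)  # dispatches to _add.forward
assert torch.allclose(z, x + y)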
Example #5
def get_kernel(block, dtype, device):
    key = (block, dtype, device)
    if key not in kernels:
        src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
        defines = {'BLOCK': block, 'TYPE': dtype}
        kernels[key] = triton.kernel(src, device=device, defines=defines)
    return kernels[key]
Example #6
        def forward(cls, ctx, logits, indices):
            n_vocab = logits.shape[-1]
            assert indices.dtype == torch.int64
            assert n_vocab % 128 == 0, \
                "Number of logit options must be divisible by 128."

            key = (logits.dtype, n_vocab)
            if key not in cls.input_config_to_kernel_fwd:
                cls.input_config_to_kernel_fwd[key] = triton.kernel(
                    cls.fwd_src,
                    device=logits.device,
                    defines={
                        "TILE": n_vocab,
                        "TYPE": logits.dtype,
                    },
                    num_warps=[4],
                )
            kernel_fwd = cls.input_config_to_kernel_fwd[key]

            # indices (and logits) are assumed to already be CUDA tensors
            result = torch.zeros_like(indices, dtype=logits.dtype)
            grid = lambda opt: (logits.shape[0], )
            kernel_fwd(logits.data_ptr(),
                       indices.data_ptr(),
                       result.data_ptr(),
                       grid=grid)
            # logits -> neg_logprobs via an in-place modification by kernel_fwd
            ctx.save_for_backward(logits, indices)

            return result
Example #7
 def make_kernel(cache, src, max_k, dtype, block, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode):
     if max_k >= 32768:
         raise NotImplementedError('Reductions larger than 32768 elements '
                                   'are not yet implemented')
     num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
     pad = num_warps * 32 * 2
     TN = (int(max_k) + pad-1)//pad * pad
     # just-in-time compile kernel
     key = (block, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode)
     if key not in cache:
         defines = {'TM': [1], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block,
                    'INFINITY': {torch.float32: 'F32_INFINITY',
                                 torch.float16: 'F16_INFINITY'}[dtype]}
         if apply_scale:
             defines['APPLY_SCALE'] = True
         if apply_rpe:
             defines['APPLY_RPE'] = True
         if apply_kp_mask:
             defines['APPLY_KP_MASK'] = True
             if kp_mask_mode == 'mul':
                 defines['KP_MASK_MUL'] = True
         if apply_attn_mask:
             defines['APPLY_ATTN_MASK'] = True
             if attn_mask_mode == 'mul':
                 defines['ATTN_MASK_MUL'] = True
         kernel = triton.kernel(src, defines=defines, num_warps=[num_warps])
         cache[key] = kernel
     return cache[key]
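To make the padding arithmetic concrete, a worked instance (not from the original source): for max_k = 1000 the heuristic selects 8 warps, so TN is rounded up to the next multiple of 8 * 32 * 2 = 512.

max_k = 1000
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)  # -> 8
pad = num_warps * 32 * 2                                       # -> 512
TN = (int(max_k) + pad - 1) // pad * pad                       # -> 1024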
Example #8
 def make_kernel(cache, src, max_k, dtype, block, kp_mask_mode, attn_mask_mode):
     # pad tile to cover the entire reduction
     params = {16384: (1, 32768, 16),
               8192:  (1, 16384, 16),
               4096:  (1, 8192, 16),
               2048:  (1, 4096, 16),
               1024:  (1, 2048, 16),
               512:   (1, 1024, 8),
               256:   (1, 512, 4),
               128:   (1, 256, 4)}
     bound = max(128, 2**int(math.log2(max_k - 1)))
     if bound not in params:
         raise NotImplementedError('Reductions larger than 32768 elements '
                                   'are not yet implemented')
     TM, TN, num_warps = params[bound]
     # just-in-time compile kernel
     key = (dtype, TM, TN, num_warps, kp_mask_mode, attn_mask_mode)
     if key not in cache:
         defines = {'TM': [TM], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block,
                    'INFINITY': {torch.float32: 'F32_INFINITY',
                                 torch.float16: 'F16_INFINITY'}[dtype]}
         if kp_mask_mode == 'mul':
             defines['KP_MASK_MUL'] = True
         if attn_mask_mode == 'mul':
             defines['ATTN_MASK_MUL'] = True
         kernel = triton.kernel(src, defines=defines, num_warps=[num_warps])
         cache[key] = kernel
     return cache[key]
Example #9
 def backward(ctx, dy):
     # lazy compilation of kernel
     if _batchnorm.bwd_kernel is None:
         _batchnorm.bwd_kernel = triton.kernel(bwd_src, defines={'TN': 128})
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
     # allocate result
     dx = triton.empty(triton.shape(x), dtype=x.dtype)
     dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
     dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
     # execute
     C, H, W, B = triton.shape(x)
     _batchnorm.bwd_kernel(dx,
                           dgamma,
                           dbeta,
                           dy,
                           x,
                           gamma,
                           mean,
                           var,
                           H * W * B,
                           eps,
                           grid=lambda opt: [1, C])
     return dx, dgamma, dbeta, None
Example #10
 def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs):
     # shapes / dtypes
     AS0 = spdims[0]
     AS1 = block * spdims[2 if trans_a else 1]
     AS2 = block * spdims[1 if trans_a else 2]
     BS0 = b.size(0)
     BS1 = b.size(1)
     BS2 = b.size(3 if trans_b else 2)
     BS3 = b.size(2 if trans_b else 3)
     dtype = a.dtype
     # kernel
     key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c)
     if key not in _matmul.dsd_cache:
          defines = {
              'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype,
              'STRIDE_AM': 1 if trans_a else block,
              'STRIDE_AK': block if trans_a else 1,
              'STRIDE_BN': 'ldb' if trans_b else '1',
              'STRIDE_BK': '1' if trans_b else 'ldb',
              'STRIDE_CM': '1' if trans_c else 'ldc',
              'STRIDE_CN': 'ldc' if trans_c else '1',
              'NAME': 'dsd_kernel', 'DSD': True
          }
         _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines)
     kernel = _matmul.dsd_cache[key]
     # output
     CS0 = BS0
     CS1 = BS1
     CS2 = BS3 if trans_c else AS1
     CS3 = AS1 if trans_c else BS3
     locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
     c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
     kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0),
            c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(),
            num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0])
     return c
Example #11
  def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
                  spdims, block, luts, num_locks, widths, packs):
    if trans_c:
      a, b = b, a
      trans_a, trans_b = not trans_b, not trans_a
    AS0 = a.size(0)
    AS1 = a.size(1)
    AS2 = a.size(3 if trans_a else 2)
    AS3 = a.size(2 if trans_a else 3)
    BS0 = b.size(0)
    BS1 = b.size(1)
    BS2 = b.size(3 if trans_b else 2)
    BS3 = b.size(2 if trans_b else 3)
    dtype = a.dtype
    device = a.device
    is_16_multiple = AS3 % 16 == 0
    is_32_multiple = AS3 % 32 == 0
    is_64_multiple = AS3 % 64 == 0
    if not is_16_multiple:
      raise ValueError('Reduction size for SDD must be a multiple of 16')
    # create kernel
    total_width = sum(width * pack * pack for width, pack in zip(widths, packs))
    c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device)
    for lut, width, pack in zip(luts, widths, packs):
      num_lock = 1
      key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple)
      if key not in _matmul.sdd_cache:
        defines = {'TM': block*pack, 'TN': block*pack,
                    'TMN': block*block*pack*pack, 
                    'BLOCK': block, 
                    'TK': 32, 
                    'TYPE': dtype,
                    'STRIDE_AM': '1'    if trans_a else 'lda', 
                    'STRIDE_AK': 'lda'  if trans_a else '1',
                    'STRIDE_BN': 'ldb'  if trans_b else '1', 
                    'STRIDE_BK': '1'    if trans_b else 'ldb',
                    'STRIDE_CM': 'ldc', 'STRIDE_CN': '1',
                    'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
        _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines)

      kernel = _matmul.sdd_cache[key]
      # create output
      locks = _matmul.get_locks(2*width*AS0*num_lock, a.device)
      # the maximum grid size is 65535, so the operation may need to be
      # decomposed into multiple kernel calls
      max_width = 49152
      for off_width in range(0, width, max_width):
        kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), 
               a.stride(2), b.stride(2), block, 
               a.stride(0), b.stride(0), c.stride(0),
               a.stride(1), b.stride(1), c.stride(0), 
               AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock, 
               grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0])
    # save for backward pass
    return c
Example #12
        def forward(cls, ctx, logits, indices):
            # make sure we can use triton
            n_vocab = logits.shape[-1]
            assert (indices.dtype == torch.int64
                    ), "Indices are expected to be of type long."

            # compile a new kernel if needed; otherwise load from a cache
            key = (logits.dtype, n_vocab)
            if key not in cls.input_config_to_kernel_fwd:
                infinities = {
                    torch.float16: "F16_INFINITY",
                    torch.float32: "F32_INFINITY",
                }
                cls.input_config_to_kernel_fwd[key] = triton.kernel(
                    cls.fwd_src,
                    device=logits.device,
                    defines={
                        "TILE": make_power_of_two(n_vocab),
                        "TYPE": logits.dtype,
                        "INFINITY": infinities[logits.dtype],
                    },
                )
            kernel_fwd = cls.input_config_to_kernel_fwd[key]

            # flatten logits and be prepared to restore them to their original shape
            original_logits_shape = logits.shape
            if len(original_logits_shape) > 2:
                logits = logits.reshape((-1, n_vocab))
                indices = indices.reshape((-1, ))

            # run the kernel and assign the result in place
            # logits and indices are assumed to already be CUDA tensors
            result = torch.empty_like(indices, dtype=logits.dtype)
            neg_logprobs = torch.empty_like(logits)
            grid = lambda opt: (logits.shape[0], )
            kernel_fwd(
                logits.data_ptr(),
                neg_logprobs.data_ptr(),
                indices.data_ptr(),
                result.data_ptr(),
                n_vocab,
                grid=grid,
            )

            if len(original_logits_shape) > 2:
                logits = logits.reshape(original_logits_shape)
                indices = indices.reshape(*original_logits_shape[:-1])

            ctx.save_for_backward(neg_logprobs, indices)
            ctx.original_logits_shape = original_logits_shape

            return result
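make_power_of_two is referenced above but not shown; a minimal sketch consistent with its use here (rounding the TILE size up to the next power of two) would be:

def make_power_of_two(n):
    # round n up to the nearest power of two (assumes n >= 1)
    return 1 << (n - 1).bit_length()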
Example #13
 def forward(ctx, x, scale, bias, res):
   if x.dtype not in _relu.fwd_kernel:
     defines = {'TYPE': x.dtype, 'TN': [128]}
     _relu.fwd_kernel[x.dtype] = triton.kernel(_relu.fwd_src, defines=defines, num_warps=[4])
   kernel = _relu.fwd_kernel[x.dtype]
   # launch kernel
   y = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
   N = x.numel()
   grid = lambda opt: [triton.cdiv(N, opt.d('TN'))]
   kernel(x, y, scale.item(), bias.item(), res, N, grid=grid)
   # update context
   ctx.save_for_backward(x, y)
   ctx.scale = scale
   return y
Example #14
 def _sdd_matmul(a, b, trans_a, trans_b, trans_c,
                 spdims, block, lut, num_locks, width, 
                 bench, time):
   if trans_c:
     a, b = b, a
     trans_a, trans_b = not trans_b, not trans_a
   AS0 = a.size(0)
   AS1 = a.size(1)
   AS2 = a.size(3 if trans_a else 2)
   AS3 = a.size(2 if trans_a else 3)
   BS0 = b.size(0)
   BS1 = b.size(1)
   BS2 = b.size(3 if trans_b else 2)
   BS3 = b.size(2 if trans_b else 3)
   dtype = a.dtype
   # create kernel
   key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
   if key not in _sparse_matmul.sdd_cache:
      defines = {'TM': block, 'TN': block, 'TK': 16, 'TYPE': dtype,
                 'STRIDE_AM': '1' if trans_a else 'lda',
                 'STRIDE_AK': 'lda' if trans_a else '1',
                 'STRIDE_BN': 'ldb' if trans_b else '1',
                 'STRIDE_BK': '1' if trans_b else 'ldb',
                 'STRIDE_CM': 'ldc',
                 'STRIDE_CN': '1',
                 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'}
     _sparse_matmul.sdd_cache[key] = triton.kernel(src, defines=defines, num_warps=[1, 2, 4])
   kernel = _sparse_matmul.sdd_cache[key]
   # create output
   locks = _sparse_matmul.get_locks(2*width*AS0*num_locks)
   c = torch.empty((AS0, width, block, block), dtype=dtype, device=a.device)
    # the maximum grid size is 65535, so the operation may need to be
    # decomposed into multiple kernel calls
   max_width = 49152
   total = 0 if bench else None
   for off_width in range(0, width, max_width):
      current = kernel(a, b, c,
                       a.stride(2), b.stride(2), block,
                       a.stride(0), b.stride(0), c.stride(0),
                       a.stride(1), b.stride(1), c.stride(0),
                       AS2, AS3, off_width, lut, locks, num_locks,
                       grid=lambda opt: [opt.d('TZ'), min(max_width, width - off_width), AS0],
                       bench=bench)
      total = total + current if bench else None
   time[0] = total
   # save for backward pass
   return c
Example #15
 def _dsd_matmul(a, b, trans_a, trans_b, trans_c,
                 spdims, block, lut, num_locks, width, packs,
                 bench, time):
   # shapes / dtypes
   AS0 = spdims[0]
   AS1 = block * spdims[2 if trans_a else 1]
   AS2 = block * spdims[1 if trans_a else 2]
   BS0 = b.size(0)
   BS1 = b.size(1)
   BS2 = b.size(3 if trans_b else 2)
   BS3 = b.size(2 if trans_b else 3)
   dtype = a.dtype
   # kernel
   key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
   if key not in _sparse_matmul.dsd_cache:
      TN = [64, 128]  # same tile candidates for fp32 and fp16
      TK = [8] if dtype == torch.float32 else [16]
     defines = {'TM': block, 'TN': TN, 'TK': TK, 
                'BLOCK': block,
                'TYPE': dtype,
                'STRIDE_AM': 1 if trans_a else block, 
                'STRIDE_AK': block if trans_a else 1,
                'STRIDE_BN': 'ldb' if trans_b else '1',
                'STRIDE_BK': '1' if trans_b else 'ldb',
                'STRIDE_CM': '1' if trans_c else 'ldc',
                'STRIDE_CN': 'ldc' if trans_c else '1',
                'NAME': 'dsd_kernel',
                'DSD': True}
     _sparse_matmul.dsd_cache[key] = triton.kernel(src, defines=defines, num_warps=[4])
   kernel = _sparse_matmul.dsd_cache[key]
   # output
   CS0 = BS0
   CS1 = BS1
   CS2 = BS3 if trans_c else AS1
   CS3 = AS1 if trans_c else BS3
   locks = _sparse_matmul.get_locks(2*BS0*BS3//32*num_locks, a.device)
   c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    time[0] = kernel(a, b, c,
                     block, b.stride(2), c.stride(2),
                     a.stride(0), b.stride(0), c.stride(0),
                     a.stride(1), b.stride(1), c.stride(1),
                     BS3, AS1, 0, 0, lut, locks, num_locks,
                     grid=lambda opt: [width, triton.cdiv(BS3, opt.d('TN')), BS0],
                     bench=bench)
   return c
Example #16
 def backward(ctx, dy):
   # load from context
   x, y = ctx.saved_tensors
   # get kernel
   if x.dtype not in _relu.bwd_kernel:
     defines = {'TYPE': x.dtype, 'TN': [128]}
     _relu.bwd_kernel[x.dtype] = triton.kernel(_relu.bwd_src, defines=defines, num_warps=[4])
   kernel = _relu.bwd_kernel[x.dtype]
   # allocate output
   dx = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
   dres = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
   dscale = torch.zeros((1,), device=dy.device, dtype=torch.float32)
   dbias = torch.zeros_like(dscale)
   # launch kernel
   N = x.numel()
   grid = lambda opt: [triton.cdiv(N, opt.d('TN'))]
   kernel(x, y, ctx.scale.item(), dx, dy, dscale, dbias, dres, N, grid=grid)
   return dx, dscale.type(x.dtype), dbias.type(x.dtype), dres
Example #17
        def backward(cls, ctx, dneg_logprobs):
            """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
            so we initialize the gradient as neg_logprobs, so we can just exponentiate
            to get p[k], which is most of what we need...  neg_logprobs will be
            modified in place to become the gradient we want
            """
            # load saved tensors and ensure correct types
            neg_logprobs, indices = ctx.saved_tensors
            original_logits_shape = ctx.original_logits_shape
            assert (
                dneg_logprobs.dtype == neg_logprobs.dtype
            ), f"Backward flowing derivatives of type {dneg_logprobs.dtype} != logits type {neg_logprobs.dtype}"
            n_vocab = neg_logprobs.shape[-1]

            # generate or load kernel
            key = (neg_logprobs.dtype, n_vocab)
            if key not in cls.input_config_to_kernel_bwd:
                cls.input_config_to_kernel_bwd[key] = triton.kernel(
                    cls.bwd_src,
                    device=neg_logprobs.device,
                    defines={
                        "TILE": make_power_of_two(n_vocab),
                        "TYPE": neg_logprobs.dtype,
                    },
                )
            kernel_bwd = cls.input_config_to_kernel_bwd[key]
            grid = lambda opt: (neg_logprobs.shape[0], )

            # neg_logprobs will be modified in place to become our gradient:
            kernel_bwd(
                neg_logprobs.data_ptr(),
                indices.data_ptr(),
                dneg_logprobs.data_ptr(),
                n_vocab,
                grid=grid,
            )

            # reshape results based on shape of original logits passed to forward
            if len(original_logits_shape) > 2:
                neg_logprobs = neg_logprobs.reshape(original_logits_shape)

            return neg_logprobs, torch.zeros_like(indices)
Example #18
def make_kernel(N, device):
    cache = make_kernel.cache
    # Kernels are now indexed not only by the provided device but also
    # by the rounded number of columns in the input matrix
    BLOCK = next_power_of_2(N)
    # Another trick we can use is to ask the compiler to parallelize each
    # row-normalization more aggressively -- i.e., with more warps -- for
    # longer vectors
    # You will see in the next tutorial how to auto-tune this value in a more
    # natural way so you don't have to come up with manual heuristics yourself
    num_warps = 4
    if BLOCK >= 2048: num_warps = 8
    if BLOCK >= 4096: num_warps = 16
    # Each (BLOCK, num_warps, device) results in a different kernel
    key = (BLOCK, num_warps, device)
    if key not in cache:
        defines = {'BLOCK': BLOCK}
        cache[key] = triton.kernel(_src, device=device, defines=defines, num_warps=num_warps)
    return cache[key]
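The next_power_of_2 helper is referenced but not defined in this snippet; the usual bit-twiddling sketch is:

def next_power_of_2(n):
    # smallest power of two >= n (assumes n >= 1)
    return 1 << (n - 1).bit_length()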
Example #19
 def _dds_matmul(a, b, trans_a, trans_b, trans_c,
             spdims, block, lut, num_locks, width, 
             bench, time):
   # shapes / dtypes
   AS0 = a.size(0)
   AS1 = a.size(1)
   AS2 = a.size(3 if trans_a else 2)
   AS3 = a.size(2 if trans_a else 3)
   BS0 = spdims[0]
   BS1 = block * spdims[2 if trans_b else 1]
   BS2 = block * spdims[1 if trans_b else 2]
   dtype = a.dtype
   # kernel
   key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
   if key not in _sparse_matmul.dds_cache:
     defines = {'TM': 128, 'TN': block, 'TK': 8, 
                'TYPE': dtype,
                'STRIDE_AM': 1 if trans_a else 'lda',
                'STRIDE_AK': 'lda' if trans_a else 1,
                'STRIDE_BN': block if trans_b else 1, 
                'STRIDE_BK': 1 if trans_b else block,
                'STRIDE_CM': '1' if trans_c else 'ldc',
                'STRIDE_CN': 'ldc' if trans_c else '1',
                'NAME': 'dds_kernel',
                'DDS': True}
     _sparse_matmul.dds_cache[key] = triton.kernel(src, defines=defines, num_warps=[4])
   kernel = _sparse_matmul.dds_cache[key]
   # output
   CS0 = AS0
   CS1 = AS1
   CS2 = BS2 if trans_c else AS2
   CS3 = AS2 if trans_c else BS2
   locks = _sparse_matmul.get_locks(2*AS0*AS2//32*num_locks)
   c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
    time[0] = kernel(a, b, c,
                     a.stride(2), block, c.stride(2),
                     a.stride(0), b.stride(0), c.stride(0),
                     a.stride(1), b.stride(1), c.stride(1),
                     AS2, 0, 0, lut, locks, num_locks,
                     grid=lambda opt: [width, triton.cdiv(AS2, opt.d('TM')), AS0],
                     bench=bench)
   return c
Example #20
 def _call(a, b):
     # create kernel if necessary
     dtype = a.dtype
     if dtype not in _dot.kernel:
         defines = {
             'TYPE': dtype,
             'STRIDE_AM': '1',
             'STRIDE_AK': 'lda',
             'STRIDE_BN': '1',
             'STRIDE_BK': 'ldb',
             'TM': [64, 128],
             'TN': [64, 128],
             'TK': [8, 16],
             'TZ': [1]
         }
         _dot.kernel[dtype] = triton.kernel(_dot.src,
                                            num_warps=[4],
                                            defines=defines)
     kernel = _dot.kernel[dtype]
     # allocate output
     M, K = a.shape
     K, N = b.shape
     c = triton.empty([M, N], dtype=dtype)
     # enqueue
     grid = lambda opt: [
         triton.cdiv(M, opt.d('TM')),
         triton.cdiv(N, opt.d('TN'))
     ]
     time = kernel(a,
                   b,
                   c,
                   1.,
                   M,
                   N,
                   K,
                   a.stride(0),
                   b.stride(0),
                   c.stride(0),
                   grid=grid,
                   bench=100)
      # report throughput; assuming `time` is in microseconds, this is GFLOP/s
      print(2 * M * N * K / (time * 1e-6) * 1e-9)
     return c
Example #21
        def backward(cls, ctx, dneg_logprobs):
            """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
            so we initialize the gradient as neg_logprobs, so we can just exponentiate
            to get p[k], which is most of what we need...  neg_logprobs will be
            modified in place to become the gradient we want
            """
            neg_logprobs, indices = ctx.saved_tensors
            assert indices.dtype == torch.int64
            assert (
                dneg_logprobs.dtype == neg_logprobs.dtype
            ), f"Backward flowing derivatives of type {dneg_logprobs.dtype} != logits type {neg_logprobs.dtype}"
            n_vocab = neg_logprobs.shape[-1]
            N = neg_logprobs.numel()

            key = (neg_logprobs.dtype, n_vocab)
            if key not in cls.input_config_to_kernel_bwd:
                cls.input_config_to_kernel_bwd[key] = triton.kernel(
                    cls.bwd_src,
                    device=neg_logprobs.device,
                    defines={
                        "TILE": n_vocab,
                        "TYPE": neg_logprobs.dtype,
                    },
                    num_warps=[4],
                )
            kernel_bwd = cls.input_config_to_kernel_bwd[key]
            grid = lambda opt: (triton.cdiv(N, opt.TILE), )

            # neg_logprobs will be modified in place to become our gradient:
            kernel_bwd(
                neg_logprobs.data_ptr(),
                indices.data_ptr(),
                dneg_logprobs.data_ptr(),
                grid=grid,
            )

            return neg_logprobs, torch.zeros_like(indices)
Example #22
 def forward(ctx, x, running_mean, running_var, gamma, beta, training,
             momentum, eps):
     N, C, H, W = x.shape
     # lazy compilation of kernel
     key = (training, x.dtype)
     if key not in _batchnorm.fwd_kernel:
         defines = {'TM': 256, 'TYPE': x.dtype}
         if training:
             defines['TRAINING'] = True
         _batchnorm.fwd_kernel[key] = triton.kernel(_batchnorm.fwd_src,
                                                    defines=defines,
                                                    num_warps=[4])
     kernel = _batchnorm.fwd_kernel[key]
     # allocate outputs
     y = torch.empty_strided(x.shape,
                             x.stride(),
                             layout=x.layout,
                             dtype=x.dtype,
                             device=x.device)
     mean = torch.empty(C, dtype=torch.float32, device=x.device)
     var = torch.empty(C, dtype=torch.float32, device=x.device)
     # execute kernels
     grid = lambda opt: [C]
     kernel(y,
            mean,
            var,
            running_mean,
            running_var,
            x,
            gamma,
            beta,
            H * W * N,
            momentum,
            eps,
            grid=grid)
     # save
     ctx.save_for_backward(x, gamma, beta, mean, var)
     ctx.eps = eps
     return y
Example #23
def make_kernel(device, dtype, n_cols, cache, name):
    rounded = next_power_of_2(n_cols)
    div = largest_pow2_divisor(n_cols)
    key = (dtype, rounded, div)
    if key not in cache:
        fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
        src = triton.read(fname, kernel_names=[name])
        infinities = {
            torch.float16: "F16_INFINITY",
            torch.float32: "F32_INFINITY",
        }
        defines = {
            "TILE": rounded,
            "TYPE": dtype,
            "INFINITY": infinities[dtype],
            "N_COLS_MULT": div
        }
        cache[key] = triton.kernel(src,
                                   device=device,
                                   defines=defines,
                                   num_warps=4)
    return cache[key]
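Both helpers that build the cache key are referenced but not shown. next_power_of_2 can be sketched as in the previous example; for largest_pow2_divisor a plausible one-liner (an assumption about its intent, not the original implementation) is:

def largest_pow2_divisor(n):
    # n & -n isolates the lowest set bit, i.e. the largest power of two dividing n
    return n & -n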
Example #24
 def backward(ctx, dy):
     # lazy compilation of kernel
     key = (dy.dtype, )
     if key not in _batchnorm.bwd_kernel:
         _batchnorm.bwd_kernel[key] = triton.kernel(_batchnorm.bwd_src,
                                                    defines={
                                                        'TM': 256,
                                                        'TYPE': dy.dtype
                                                    },
                                                    num_warps=[4])
     kernel = _batchnorm.bwd_kernel[key]
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
     # allocate result
     dx = torch.empty_strided(x.shape,
                              x.stride(),
                              layout=x.layout,
                              dtype=x.dtype,
                              device=x.device)
     dgamma = torch.empty_like(gamma)
     dbeta = torch.empty_like(beta)
     # execute
     N, C, H, W = x.shape
     kernel(dx,
            dgamma,
            dbeta,
            dy,
            x,
            gamma,
            mean,
            var,
            H * W * N,
            eps,
            grid=lambda opt: [C])
     return dx, None, None, dgamma, dbeta, None, None, None
Example #25
 def forward(ctx, a, b, pad, stride):
     # create kernel if necessary
     dtype = a.dtype
     device = a.device
     # shapes
     Z, CI, H, W = a.shape
     _, R, S, CO = b.shape
     P = (H + 2 * pad[0] - R) // stride[0] + 1
     Q = (W + 2 * pad[1] - S) // stride[1] + 1
     # compile kernel
     if (dtype, device) not in _conv.kernel:
         TK = 16
         defines = {
             'TYPE': dtype,
             'TM': [32, 64, 128],
             'TN': [32, 64, 128],
             'TK': [TK],
             'TZ': [1],
             'HH': H,
             'WW': W,
             'PP': P,
             'QQ': Q,
             'SS': S,
             'RR': R,
         }
         idx = torch.arange(CI * R * S)
         ci, r, s = _conv.unpack(idx, CI, R, S)
         nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
          delta = ((nci - ci) * a.stride(1) + (nr - r) * a.stride(2) +
                   (ns - s) * a.stride(3))
          delta = delta.type(torch.int32).cuda()
          # cache under the same (dtype, device) key that is checked above
          _conv.kernel[(dtype, device)] = (delta,
                                           triton.kernel(_conv.src,
                                                         device=device,
                                                         num_warps=[4],
                                                         defines=defines))
      delta, kernel = _conv.kernel[(dtype, device)]
     # allocate output
     c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
     # enqueue
     kernel(a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(),
            1.,
            Z * P * Q,
            CO,
            CI * R * S,
            pad[0],
            pad[1],
            stride[0],
            stride[1],
            delta.data_ptr(),
            a.stride(0),
            a.stride(1),
            a.stride(2),
            a.stride(3),
            b.stride(0),
            b.stride(1),
            b.stride(2),
            b.stride(3),
            c.stride(0),
            c.stride(1),
            c.stride(2),
            c.stride(3),
            grid=lambda opt:
            [triton.cdiv(Z * P * Q, opt.TM),
             triton.cdiv(CO, opt.TN)])
     return c
Example #26
class _dot(triton.function):

  src = """
void dot(TYPE * A __noalias __readonly __aligned(16),
         TYPE * B __noalias __readonly __aligned(16),
         TYPE * C,
         float alpha,
         int M, int N, int K,
         int lda __multipleof(8),
         int ldb __multipleof(8),
         int ldc) {
    // prologue
      int ridx = get_program_id(0);
      int ridy = get_program_id(1);
      int rm[TM] = ridx * TM + 0 ... TM;
      int rn[TN] = ridy * TN + 0 ... TN;
      int rk[TK] = 0 ... TK;

      // pointers to operands
      TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
      TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;

      // prefetches operands
      bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
      bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
      TYPE a[SHAPE_A] = checka ? *pa : 0;
      TYPE b[SHAPE_B] = checkb ? *pb : 0;

      // reduction loop
      float c[TM, TN] = 0;
      for(int k = K; k > 0; k -= TK){
        c += USE_A @ USE_B;
        bool checka[SHAPE_A] = k > TK;
        bool checkb[SHAPE_B] = k > TK;
        pa += TK * STRIDE_AK;
        pb += TK * STRIDE_BK;
        a = *?(checka)pa;
        b = *?(checkb)pb;
      }
      //c = c * alpha;

      // epilogue
      int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
      int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
      TYPE* pc[TM, TN] = C + rxm[:, newaxis] * ldc + rxn[newaxis, :];
      bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
      *?(checkc)pc = (TYPE[TM, TN])c;
}
"""
  kernel = triton.kernel(src, ['C'])

  @staticmethod
  def _call(a, b, transpose_a, transpose_b, bench):
    # extract shapes
    shape_a = triton.shape(a)
    shape_b = triton.shape(b)
    M, Ka = shape_a[0], shape_a[1]
    Kb, N = shape_b[0], shape_b[1]
    # transpose shapes
    if transpose_a:
      M, Ka = Ka, M
    if transpose_b:
      Kb, N = N, Kb
    # contiguous dimensions
    lda = M if transpose_a else Ka
    ldb = Kb if transpose_b else N
    ldc = N
    # data-type
    dtype = a.dtype
    # allocate output
    c = triton.empty([M, N], dtype=dtype)
    # compute
    grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
    # macros -- not necessary but makes kernel source-code simpler
    macros = {# handle A transposition
              'USE_A'       : '^a'         if transpose_a else 'a',
              'STRIDE_AK'   : 'lda'        if transpose_a else '1',
              'STRIDE_AM'   : '1'          if transpose_a else 'lda',
              'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
              'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
              'SHAPE_A'     : 'TK, TM'     if transpose_a else 'TM, TK',
              # handle B transposition
              'USE_B'       : '^b'         if transpose_b else 'b',
              'STRIDE_BK'   : '1'          if transpose_b else 'ldb',
              'STRIDE_BN'   : 'ldb'        if transpose_b else '1',
              'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
              'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
              'SHAPE_B'     : 'TN, TK'     if transpose_b else 'TK, TN'}
    _dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc,
                grid, bench=bench,
                AT=transpose_a, BT=transpose_b, TYPE=dtype,
                TM=[64], TN=[128], TK=[8], **macros)
    return c

  @staticmethod
  def forward(ctx, a, b, transpose_a=False, transpose_b=False, bench=0):
    ctx.save_for_backward(a, b)
    ctx.t_a = transpose_a
    ctx.t_b = transpose_b
    ctx.bench = bench
    return _dot._call(a, b, transpose_a, transpose_b, bench)

  @staticmethod
  def backward(ctx, dy):
    a, b = ctx.saved_tensors
    t_a, t_b = ctx.t_a, ctx.t_b
    bench = ctx.bench
    if not t_a and not t_b:
      da = _dot._call(dy, b, False, True, bench)
      db = _dot._call(a, dy, True, False, bench)
    elif not t_a and t_b:
      da = _dot._call(dy, b, False, False, bench)
      db = _dot._call(dy, a, True, False, bench)
    elif t_a and not t_b:
      da = _dot._call(b, dy, False, True, bench)
      db = _dot._call(a, dy, False, False, bench)
    elif t_a and t_b:
      da = _dot._call(b, dy, True, True, bench)
      db = _dot._call(dy, a, True, True, bench)
    else:
      assert False
    return da, db, None, None, None
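Assuming the legacy triton.function base class mirrors torch.autograd.Function's apply interface, the wrapper above would be used roughly as follows (a sketch, not from the original source):

import torch

a = torch.rand(256, 64, device='cuda')
b = torch.rand(64, 128, device='cuda')
c = _dot.apply(a, b)  # (256, 128); gradients route through _dot.backward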
Example #27
 def _call(a, b):
     dtype = a.dtype
     device = a.device
     # allocate output
     M, K = a.shape
     K, N = b.shape
     c = torch.empty((M, N), dtype=dtype, device=device)
     # handle non-contiguous inputs if necessary
     if a.stride(0) > 1 and a.stride(1) > 1:
         a = a.contiguous()
     if b.stride(0) > 1 and b.stride(1) > 1:
         b = b.contiguous()
     # kernel hash
     is_a_row = a.stride(1) == 1
     is_b_row = b.stride(1) == 1
     lda = a.stride(0) if is_a_row else a.stride(1)
     ldb = b.stride(0) if is_b_row else b.stride(1)
     ldc = c.stride(0)
     lda_pow2_div = _matmul.largest_pow2_divisor(lda)
     ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
     ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
     is_tk_div_k = K % 64 == 0
     key = (
         device,
         dtype,
         is_a_row,
         is_b_row,
         lda_pow2_div,
         ldb_pow2_div,
         ldc_pow2_div,
         is_tk_div_k,
     )
     if key not in _matmul._kernels:
         defines = {
             "TYPE": dtype,
             "STRIDE_AM": "lda" if is_a_row else "1",
             "STRIDE_AK": "1" if is_a_row else "lda",
             "STRIDE_BK": "ldb" if is_b_row else "1",
             "STRIDE_BN": "1" if is_b_row else "ldb",
             "LDA_POW2_DIV": lda_pow2_div,
             "LDB_POW2_DIV": ldb_pow2_div,
             "LDC_POW2_DIV": ldc_pow2_div,
             "IS_TK_DIV_K": int(is_tk_div_k),
         }
         _matmul._kernels[key] = triton.kernel(
             _matmul.src,
             device,
             defines=defines,
             autotune_vals=_matmul._CONFIGS,
             autotune_key=["M", "N", "K"],
         )
     kernel = _matmul._kernels[key]
      # locks for split-k
     if device not in _matmul._locks:
         _matmul._locks[device] = torch.zeros(1024 * 1024,
                                              dtype=torch.int32,
                                              device=device)
     locks = _matmul._locks[device]
     # enqueue
     alpha = 1.0
     args = [
         a.data_ptr(),
         b.data_ptr(),
         c.data_ptr(),
         alpha,
         M,
         N,
         K,
         lda,
         ldb,
         ldc,
         locks.data_ptr(),
     ]
     grid = lambda opt: [
         triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
         1,
         opt.SPLITK,
     ]
     kernel(*args, grid=grid)
     return c
Example #28
    def make_kernel(name, dtype, mask, expr_a, expr_b, expr_c, axes_m, axes_n,
                    axes_k, axes_b, multipleof_a, multipleof_b, multipleof_c,
                    stride_a_last, stride_b_last, stride_c_last, lut_mode_a,
                    lut_mode_b, delta_a, delta_b, subscripted, varnames):

        use_lut_a = True
        use_lut_b = True

        src = ""

        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
            src += f"""
char __constant__* AD = calloc({4*len(delta_a)});"""
        if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
            src += f"""
char __constant__* BD = calloc({4*len(delta_b)});"""

        src += f"""
__global__ void {name}(
              TYPE * A __noalias __readonly __aligned(16)
            , TYPE * B __noalias __readonly __aligned(16)
            , TYPE * C
            , int * locks
            , float alpha
            , int matmul_m, int matmul_n, int matmul_k __multipleof(16)
            , int div_m
            """
        for dim in [axes_m, axes_n, axes_k, axes_b]:
            for d in dim:
                src += f", int dim_{d}"
        src += "\n            "
        for dim, name, mult in zip([expr_a, expr_b, expr_c], ['a', 'b', 'c'],
                                   [multipleof_a, multipleof_b, multipleof_c]):
            for d in range(len(dim) - 1):
                attr = f'__multipleof({mult})'
                src += f", int stride_{name}_{d} {attr}"
            src += "\n            "
        if lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_a_inner __multipleof({multipleof_a})"
            src += f", int rem_delta_a __multipleof({multipleof_a})"
        elif lut_mode_a == _einsum.LUT_MODE.DRAM:
            src += ", int* AD __noalias __readonly __aligned(16)"
        src += "\n            "
        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += f", int stride_b_inner __multipleof({multipleof_b})"
            src += f", int rem_delta_b __multipleof({multipleof_b})"
        elif lut_mode_b == _einsum.LUT_MODE.DRAM:
            src += ", int* BD"
        src += "\n"
        for ptr in subscripted:
            src += f", int* {ptr}"
        for name in varnames:
            src += f", int {name}"
        src += """) {

    // re-order outer program ids
    int grid_m = (matmul_m + TM - 1) / TM;
    int grid_n = (matmul_n + TN - 1) / TN;
    int pid_mn = get_program_id(0) / div_m;
    int pid_n = pid_mn % grid_n;
    int pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);

    // get batch program id
    int pid_b = get_program_id(1);

#if TZ == 1
    int off_k = 0;
#else
    // get reduction sub-group program id
    int pid_z = get_program_id(2);
    int grid_z = get_num_programs(2);
    int div_z = matmul_k / TZ;
    int rem_z = matmul_k % TZ;
    int off_k = pid_z * div_z;
    matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif
    int rem_k = matmul_k % TK;
    
    // create ranges
"""
        rk = 'r{}'.format(''.join(map(str, axes_k)))
        for axes, tile, off in zip(
            [axes_m, axes_n, axes_b, axes_k], ['TM', 'TN', 'TB', 'TK'],
            ['pid_m*TM', 'pid_n*TN', 'pid_b*TB', 'off_k']):
            currs = ''.join(map(str, axes))
            if axes:
                src += f"    int r{currs}[{tile}] = {off} + 0 ... {tile};\n"
                src += _einsum.unpack_cc(tile, axes, 'r', False)

        src += """    
    // initialize pointers to A
    int offa[TM, TK, TB] = """
        for i, sym in enumerate(expr_a):
            ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
            stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}'
            if i > 0:
                src += ' + '
            src += f"({ccode}) * {stride}\n                            "
        src += ';'

        src += """
    TYPE *pa[TM, TK, TB] = A + offa;"""

        if use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
    // initialize pointers to A look-up table
    int offadelta[TK] = off_k + 0 ... TK;
    int {spec} *padelta[TK]  = {cast}AD  + offadelta;
    int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];"""

        src += """

    // initialize pointers to B
    int offb[TK, TN, TB] = """
        for i, sym in enumerate(expr_b):
            ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b)
            stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}'
            if i > 0:
                src += ' + '
            src += f"({ccode}) * {stride}\n                            "
        src += ';'

        src += """
    TYPE *pb[TK, TN, TB] = B + offb;"""

        if use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR:
            spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else ''
            src += f"""
    // initialize pointers to B look-up table
    int offbdelta[TK] = off_k + 0 ... TK;
    int *pbdelta[TK]  = BD  + offbdelta;"""

        src += f"""

    // prefetch 
    int prefetch_k = select(rem_k > 0, rem_k, TK);
    bool checkm[TM] = r""" + ''.join(map(str, axes_m)) + f""" < matmul_m;
    bool checkn[TN] = r""" + ''.join(map(str, axes_n)) + f""" < matmul_n;
    bool checkk[TK] = {rk} < prefetch_k;
    bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
    TYPE a[TM, TK, TB] = checka ? *pa : 0;
    TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""

        if lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += """
    pa += rem_delta_a;"""
        else:
            src += """
    pa += incda;
    padelta += TK;
    incda = (*padelta)[newaxis, :, newaxis];"""

        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += """
    pb += rem_delta_b;"""
        else:
            src += """
    pb += (*pbdelta)[:, newaxis, newaxis];
    pbdelta += TK;"""

        src += f"""
    // accumulate
    float acc[TM, TN, TB] = 0;
    for(int k = matmul_k; k > 0; k -= TK) {{
        acc += a @ b;
        #ifdef MASK
        uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
        acc = bitcast<float[TM, TN, TB]>(bits & MASK);
        #endif

        checkk = k > TK;
        checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
        checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
        a = *?(checka)pa;
        b = *?(checkb)pb;"""

        if lut_mode_a == _einsum.LUT_MODE.SCALAR:
            src += """
        pa += stride_a_inner;"""
        else:
            src += """
        pa += incda;
        padelta += TK;
        incda = (*padelta)[newaxis, :, newaxis];"""

        if lut_mode_b == _einsum.LUT_MODE.SCALAR:
            src += """
        pb += stride_b_inner;"""
        else:
            src += """
        pb += (*pbdelta)[:, newaxis, newaxis];
        pbdelta += TK;"""

        src += f"""
    }}
    TYPE c[TM, TN, TB] = acc;

    // re-materialize ranges
    pid_mn = get_program_id(0) / div_m;
    pid_n = pid_mn % grid_n;
    pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m);
"""
        for axes, tile, off in zip([axes_m, axes_n, axes_b],
                                   ['TM', 'TN', 'TB'],
                                   ['pid_m*TM', 'pid_n*TN', 'pid_b*TB']):
            currs = ''.join(map(str, axes))
            if axes:
                src += f"    r{currs} = {off} + 0 ... {tile};\n"
                src += _einsum.unpack_cc(tile, axes, 'r', True)

        src += """
    // initialize pointers to C
    int offc[TM, TN, TB] = """
        for i, sym in enumerate(expr_c):
            stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}'
            ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b)
            if i > 0:
                src += ' + '
            src += f"({ccode}) * {stride}\n                            "
        src += ';'

        src += """
    TYPE *pc[TM, TN, TB] = C + offc;
    
    // bounds-checking
    checkm = r""" + ''.join(map(str, axes_m)) + """ < matmul_m;
    checkn = r""" + ''.join(map(str, axes_n)) + """ < matmul_n;
    bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] && 
                              checkn[newaxis, :, newaxis];

    // write back
#if TZ == 1
    *?(checkc)pc = c;
#else
    int *plock = locks + pid_mn + pid_b * get_num_programs(0);
    int *pcount = plock + 1024*1024;
    // spin
    for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
    int count = *pcount;
    if(count == 0)
      *?(checkc)pc = c;
    else
      *?(checkc)pc = c + *?(checkc)pc;
    atomic_xchg(pcount, (count + 1) % (grid_z));
    atomic_xchg(plock, 0);
#endif
}
"""
        # compilation options
        TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16]
        TK = 16 if dtype == torch.float16 else 8
        defines = {
            'TM': TM,
            'TN': TN,
            'TB': TB,
            'TK': TK,
            'TZ': TZ,
            'TYPE': dtype
        }
        if mask is not None:
            defines['MASK'] = '{0:#0{1}x}'.format(mask, 10)
        # create kernel
        ret = triton.kernel(src, defines=defines)
        # set constant
        if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
            ret.set_constant('AD', delta_a)
        if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
            ret.set_constant('BD', delta_b)
        return ret
Example #29
    def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts,
                    num_locks, widths, packs, bench, time):
        if trans_c:
            a, b = b, a
            trans_a, trans_b = not trans_b, not trans_a
        AS0 = a.size(0)
        AS1 = a.size(1)
        AS2 = a.size(3 if trans_a else 2)
        AS3 = a.size(2 if trans_a else 3)
        BS0 = b.size(0)
        BS1 = b.size(1)
        BS2 = b.size(3 if trans_b else 2)
        BS3 = b.size(2 if trans_b else 3)
        dtype = a.dtype
        if dtype == torch.float16 and AS3 % 64 > 0:
            raise ValueError(
                'Reduction size for SDD must be a multiple of 64 in FLOAT16')
        if dtype == torch.float32 and AS3 % 16 > 0:
            raise ValueError(
                'Reduction size for SDD must be a multiple of 16 in FLOAT32')
        # create kernel
        total_width = sum(
            [width * pack * pack for width, pack in zip(widths, packs)])
        c = torch.empty((AS0, total_width, block, block),
                        dtype=dtype,
                        device=a.device)
        for lut, width, pack in zip(luts, widths, packs):
            num_lock = 1
            key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack)
            if key not in _sparse_matmul.sdd_cache:
                TK = {
                    torch.float32: [8, 16],
                    torch.float16: [16, 32, 64]
                }[dtype]
                defines = {
                    'TM': block * pack,
                    'TN': block * pack,
                    'TMN': block * block * pack * pack,
                    'BLOCK': block,
                    'TK': TK,
                    'TYPE': dtype,
                    'STRIDE_AM': '1' if trans_a else 'lda',
                    'STRIDE_AK': 'lda' if trans_a else '1',
                    'STRIDE_BN': 'ldb' if trans_b else '1',
                    'STRIDE_BK': '1' if trans_b else 'ldb',
                    'STRIDE_CM': 'ldc',
                    'STRIDE_CN': '1',
                    'SDD': True,
                    'TZ': 1,
                    'NAME': 'sdd_kernel'
                }
                _sparse_matmul.sdd_cache[key] = triton.kernel(
                    src, defines=defines, num_warps=[1, 2, 4])

            kernel = _sparse_matmul.sdd_cache[key]
            # create output
            locks = _sparse_matmul.get_locks(2 * width * AS0 * num_lock,
                                             a.device)
            # the maximum grid size is 65535, so the operation may need to be
            # decomposed into multiple kernel calls
            max_width = 49152
            total = 0 if bench else None
            for off_width in range(0, width, max_width):
                current = kernel(
                    a,
                    b,
                    c,
                    a.stride(2),
                    b.stride(2),
                    block,
                    a.stride(0),
                    b.stride(0),
                    c.stride(0),
                    a.stride(1),
                    b.stride(1),
                    c.stride(0),
                    AS2,
                    AS2,
                    AS3,
                    off_width,
                    lut,
                    locks,
                    num_lock,
                    grid=lambda opt:
                    [opt.d('TZ'),
                     min(max_width, width - off_width), AS0],
                    bench=bench)
                total = total + current if bench else None
            time[0] = total
        # save for backward pass
        return c
Example #30
class _batchnorm(triton.function):

  fwd_src = """
void fwdbatchnorm(float *Y, float *M, float *V,
                  float *X, float *G, float *B,
                  int N, float eps) {
  // pointers
  int c = get_program_id(1);
  int rm[TM] = 0 ... TM;
  float *px[TM] = X + rm + c*N;
  float* py[TM] = Y + rm + c*N;

  // compute mean
  float accm[TM] = 0;
  for(int i = 0; i < N; i = i + TM)
    accm = accm + *(px + i);
  float mean = (float)accm[+] / N;
  *(M + c) = mean;

  // compute variance
  float accv[TM] = 0;
  for(int i = 0; i < N; i = i + TM){
    float x[TM] = *(px + i);
    x = x - mean;
    accv = accv + x*x;
  }
  float var = (float)accv[+] / N;
  *(V + c) = var;

  // Normalize batch
  float gamma = *(G + c);
  float beta = *(B + c);
  float rstdg = 1 / sqrtf(var + eps) * gamma;
  for(int i = 0; i < N; i = i + TM){
    float x[TM] = *(px + i);
    float y[TM] = (x - mean)*rstdg + beta;
    *(py + i) = y;
  }
}
"""
  fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])

  bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
                  float *DY, float *X, float *G,
                  float *M, float *V,
                  int N, float epsilon) {

   // pointers
  int c = get_program_id(1);
  int rx[TM] = 0 ... TM;
  int offset = c*N;
  float* px[TM]  =  X + rx + offset;
  float* pdy[TM] = DY + rx + offset;
  float* pdx[TM] = DX + rx + offset;

  // fetch statistics
  float gamma = *(G + c);
  float mean = *(M + c);
  float var = *(V + c);
  float rstd = 1 / sqrtf(var + epsilon);

  // compute dgamma and dbeta
  float  acc_dg[TM] = 0;
  float  acc_db[TM] = 0;
  for(int i = 0; i < N; i = i + TM){
    float x[TM] = *(px + i);
    float dy[TM] = *(pdy + i);
    acc_dg += dy*(x - mean)*rstd;
    acc_db += dy;
  }
  float dg = acc_dg[+];
  float db = acc_db[+];
  *(DG + c) = dg;
  *(DB + c) = db;

  // compute dx
  for(int i = 0; i < N; i = i + TM){
    float x[TM] = *(px + i);
    float dy[TM] = *(pdy + i);
    float xhat[TM] = (x - mean) * rstd;
    float xtmp[TM] = (xhat * dg + db) / N;
    float dx[TM] = (dy - xtmp) * rstd * gamma;
    *(pdx + i) = dx;
  }
}
"""
  bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])

  @staticmethod
  def forward(ctx, x, gamma, beta, eps):
    shape = triton.shape(x)
    dtype = x.dtype
    # allocate outputs
    C, H, W, B = shape[0], shape[1], shape[2], shape[3]
    y = triton.empty(shape, dtype=dtype)
    mean = triton.empty([C], dtype=dtype)
    var = triton.empty([C], dtype=dtype)
    # execute kernels
    _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
                          grid = lambda opt: [1, C],
                          defines = {'TM': 128})
    # save
    ctx.save_for_backward(x, gamma, beta, mean, var)
    ctx.eps = eps
    return y

  @staticmethod
  def backward(ctx, dy):
    # retrieve info
    x, gamma, beta, mean, var = ctx.saved_tensors
    eps = ctx.eps
    # allocate result
    dx = triton.empty(triton.shape(x), dtype=x.dtype)
    dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
    dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
    # execute
    C, H, W, B = triton.shape(x)
    _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy, 
                          x, gamma, mean, var, 
                          H*W*B, eps,
                          grid = lambda opt: [1, C],
                          defines = {'TM': 128})
    return dx, dgamma, dbeta, None
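The forward pass above indexes shapes as C, H, W, B, i.e. channels first and batch last. A usage sketch under that layout assumption, again assuming triton.function mirrors torch.autograd.Function's apply interface:

import torch

C, H, W, B = 32, 14, 14, 64
x = torch.rand(C, H, W, B, device='cuda')
gamma = torch.ones(C, device='cuda')
beta = torch.zeros(C, device='cuda')
y = _batchnorm.apply(x, gamma, beta, 1e-5)  # per-channel mean/var over H*W*B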