Code example #1
import os
import triton

# module-level cache of compiled kernels, keyed by (block, dtype, device)
kernels = dict()

def get_kernel(block, dtype, device):
    key = (block, dtype, device)
    if key not in kernels:
        src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'))
        defines = {'BLOCK': block, 'TYPE': dtype}
        kernels[key] = triton.kernel(src, device=device, defines=defines)
    return kernels[key]
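
A minimal usage sketch for the cached kernel above; the launched kernel's argument list is an assumption, since softmax.c itself is not part of the excerpt:

# hypothetical usage: the argument list of the softmax.c kernel is assumed,
# not taken from the excerpt above
import torch

x = torch.randn(128, 512, device='cuda', dtype=torch.float32)
y = torch.empty_like(x)
kernel = get_kernel(block=512, dtype=x.dtype, device=x.device)
kernel(x.data_ptr(), y.data_ptr(), x.stride(0),
       grid=lambda opt: [x.shape[0]])  # one program instance per row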
Code example #2
File: cross_entropy.py Project: nottombrown/triton
def make_kernel(device, dtype, n_cols, cache, name):
    rounded = next_power_of_2(n_cols)
    div = largest_pow2_divisor(n_cols)
    key = (dtype, rounded, div)
    if key not in cache:
        fname = os.path.join(os.path.dirname(__file__), "cross_entropy.c")
        src = triton.read(fname, kernel_names=[name])
        infinities = {
            torch.float16: "F16_INFINITY",
            torch.float32: "F32_INFINITY",
        }
        defines = {
            "TILE": rounded,
            "TYPE": dtype,
            "INFINITY": infinities[dtype],
            "N_COLS_MULT": div
        }
        cache[key] = triton.kernel(src,
                                   device=device,
                                   defines=defines,
                                   num_warps=4)
    return cache[key]
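
make_kernel relies on two small helpers that are not part of this excerpt; the sketch below mirrors the definitions that appear in code examples #6 and #9 further down:

# helper sketches matching the definitions shown in code examples #6 and #9
def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    return n + 1

def largest_pow2_divisor(N):
    if N % 8 == 0: return 8
    if N % 4 == 0: return 4
    if N % 2 == 0: return 2
    return 1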
Code example #3
File: conv.py Project: daadaada/triton
class _conv(torch.autograd.Function):
    src = triton.read(os.path.join(os.path.dirname(__file__), 'conv.c'))
    kernel = dict()

    @staticmethod
    def unpack(IDX, CI, R, S):
        s = IDX % S
        cr = IDX // S
        r = cr % R
        ci = cr // R
        return ci, r, s

    @staticmethod
    def forward(ctx, a, b, pad, stride):
        # create kernel if necessary
        dtype = a.dtype
        device = a.device
        # shapes
        Z, CI, H, W = a.shape
        _, R, S, CO = b.shape
        P = (H + 2 * pad[0] - R) // stride[0] + 1
        Q = (W + 2 * pad[1] - S) // stride[1] + 1
        # compile kernel
        if (dtype, device) not in _conv.kernel:
            TK = 16
            defines = {
                'TYPE': dtype,
                'TM': [32, 64, 128],
                'TN': [32, 64, 128],
                'TK': [TK],
                'TZ': [1],
                'HH': H,
                'WW': W,
                'PP': P,
                'QQ': Q,
                'SS': S,
                'RR': R,
            }
            idx = torch.arange(CI * R * S)
            ci, r, s = _conv.unpack(idx, CI, R, S)
            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
            delta = (nci - ci) * a.stride(1) + (nr - r) * a.stride(2) + (
                ns - s) * a.stride(3)
            delta = delta.type(torch.int32).cuda()
            # cache under the same (dtype, device) key used in the check above
            _conv.kernel[(dtype, device)] = (
                delta,
                triton.kernel(_conv.src,
                              device=device,
                              num_warps=[4],
                              defines=defines))
        delta, kernel = _conv.kernel[(dtype, device)]
        # allocate output
        c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device)
        # enqueue
        kernel(a.data_ptr(),
               b.data_ptr(),
               c.data_ptr(),
               1.,
               Z * P * Q,
               CO,
               CI * R * S,
               pad[0],
               pad[1],
               stride[0],
               stride[1],
               delta.data_ptr(),
               a.stride(0),
               a.stride(1),
               a.stride(2),
               a.stride(3),
               b.stride(0),
               b.stride(1),
               b.stride(2),
               b.stride(3),
               c.stride(0),
               c.stride(1),
               c.stride(2),
               c.stride(3),
               grid=lambda opt:
               [triton.cdiv(Z * P * Q, opt.TM),
                triton.cdiv(CO, opt.TN)])
        return c
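
A hypothetical invocation of the _conv Function above; the shapes are illustrative only and the backward pass is not shown in this excerpt:

# hypothetical usage sketch for the _conv Function above (shapes are illustrative)
a = torch.randn(8, 64, 32, 32, device='cuda', dtype=torch.float32)   # (Z, CI, H, W)
b = torch.randn(64, 3, 3, 128, device='cuda', dtype=torch.float32)   # (CI, R, S, CO)
c = _conv.apply(a, b, (1, 1), (1, 1))                                # pad, stride
# c.shape == (8, 128, 32, 32), i.e. (Z, CO, P, Q) with P = Q = (32 + 2 - 3) // 1 + 1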
Code example #4
    class _softmax_xent_loss(torch.autograd.Function):
        """This modifies logits in place, turning them into negative logprobs
        on the forward pass.  It should not copy the logits at all.
        """

        fwd_src = triton.read(
            os.path.join(
                os.path.dirname(__file__),
                "softmax_xent_kernels.c",
            ),
            kernel_names=["softmax_fwd"],
        )
        bwd_src = triton.read(
            os.path.join(
                os.path.dirname(__file__),
                "softmax_xent_kernels.c",
            ),
            kernel_names=["softmax_bwd"],
        )

        # TILE must cover a full vocabulary row (n_vocab rounded up to a power of two):
        input_config_to_kernel_fwd = {}
        input_config_to_kernel_bwd = {}

        @classmethod
        def forward(cls, ctx, logits, indices):
            # make sure we can use triton
            n_vocab = logits.shape[-1]
            assert (indices.dtype == torch.int64
                    ), "Indices are expected to be of type long."

            # compile a new kernel if needed; otherwise load from a cache
            if not (logits.dtype, n_vocab) in cls.input_config_to_kernel_fwd:
                infinities = {
                    torch.float16: "F16_INFINITY",
                    torch.float32: "F32_INFINITY",
                }
                cls.input_config_to_kernel_fwd[(
                    logits.dtype, n_vocab)] = triton.kernel(
                        cls.fwd_src,
                        device=logits.device,
                        defines={
                            "TILE": make_power_of_two(n_vocab),
                            "TYPE": logits.dtype,
                            "INFINITY": infinities[logits.dtype],
                        },
                    )
            kernel_fwd = cls.input_config_to_kernel_fwd[(logits.dtype,
                                                         n_vocab)]

            # flatten logits and be prepared to restore them to their original shape
            original_logits_shape = logits.shape
            if len(original_logits_shape) > 2:
                logits = logits.reshape((-1, n_vocab))
                indices = indices.reshape((-1, ))

            # run the kernel and assign the result in place
            result = torch.empty_like(indices, dtype=logits.dtype).cuda()
            neg_logprobs = torch.empty_like(logits, dtype=logits.dtype).cuda()
            grid = lambda opt: (logits.shape[0], )
            kernel_fwd(
                logits.data_ptr(),
                neg_logprobs.data_ptr(),
                indices.data_ptr(),
                result.data_ptr(),
                n_vocab,
                grid=grid,
            )

            if len(original_logits_shape) > 2:
                logits = logits.reshape(original_logits_shape)
                indices = indices.reshape(*original_logits_shape[:-1])

            ctx.save_for_backward(neg_logprobs, indices)
            ctx.original_logits_shape = original_logits_shape

            return result

        @classmethod
        def backward(cls, ctx, dneg_logprobs):
            """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
            so we initialize the gradient as neg_logprobs, so we can just exponentiate
            to get p[k], which is most of what we need...  neg_logprobs will be
            modified in place to become the gradient we want
            """
            # load saved tensors and ensure correct types
            neg_logprobs, indices = ctx.saved_tensors
            original_logits_shape = ctx.original_logits_shape
            assert (
                dneg_logprobs.dtype == neg_logprobs.dtype
            ), f"Backward flowing derivatives of type {dneg_logprobs.dtype} != logits type {neg_logprobs.dtype}"
            n_vocab = neg_logprobs.shape[-1]

            # generate or load kernel
            if not (neg_logprobs.dtype,
                    n_vocab) in cls.input_config_to_kernel_bwd:
                cls.input_config_to_kernel_bwd[(
                    neg_logprobs.dtype, n_vocab)] = triton.kernel(
                        cls.bwd_src,
                        device=neg_logprobs.device,
                        defines={
                            "TILE": make_power_of_two(n_vocab),
                            "TYPE": neg_logprobs.dtype,
                        },
                    )
            kernel_bwd = cls.input_config_to_kernel_bwd[(neg_logprobs.dtype,
                                                         n_vocab)]
            grid = lambda opt: (neg_logprobs.shape[0], )

            # neg_logprobs will be modified in place to become our gradient:
            kernel_bwd(
                neg_logprobs.data_ptr(),
                indices.data_ptr(),
                dneg_logprobs.data_ptr(),
                n_vocab,
                grid=grid,
            )

            # reshape results based on shape of original logits passed to forward
            if len(original_logits_shape) > 2:
                neg_logprobs = neg_logprobs.reshape(original_logits_shape)

            return neg_logprobs, torch.zeros_like(indices)
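
A hypothetical call pattern for the Function above, assuming the usual .apply entry point; shapes, dtype, and vocabulary size are illustrative:

# hypothetical usage sketch (shapes, dtype and vocabulary size are illustrative)
logits = torch.randn(4, 16, 1024, device='cuda', dtype=torch.float32, requires_grad=True)
targets = torch.randint(0, 1024, (4, 16), device='cuda')      # int64 indices
loss = _softmax_xent_loss.apply(logits, targets)              # per-token -log p[target]
loss.sum().backward()                                         # gradients flow to logits.grad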
Code example #5
File: matmul.py Project: jareddk/triton
import triton
import triton._C.libtriton as libtriton
import torch
import os
import math

src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))

##############
#  MAIN API  #
##############
class _matmul(torch.autograd.Function):
  
  sdd_cache = dict()
  dsd_cache = dict()
  dds_cache = dict()
  locks = dict()

  # Given an array `sizes` holding the reduction size of each column of a
  # block-sparse matrix multiplication, perform load-balancing by splitting
  # large reductions into several smaller segments of at most `seg_size`
  # elements each.
  @staticmethod
  def load_balance(sizes, block):
    # segment size
    # heuristics taken from OpenAI blocksparse code
    # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95
    max_size = sizes.max()
    min_size = sizes[sizes != 0].min()
    #if max_size > min_size * 2.0:
    #  seg_max = max(triton.cdiv(max_size, 4), min_size*2)
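
The excerpt stops in the middle of the heuristic; as a rough illustration only (not the library's actual algorithm), capping each column's reduction at `seg_size` elements amounts to a ceil division:

# illustration only, not the actual load-balancing heuristic:
# number of segments per column when each reduction is capped at seg_size elements
import torch
sizes = torch.tensor([1, 4, 9, 2])
seg_size = 3
num_segments = (sizes + seg_size - 1) // seg_size   # tensor([1, 2, 3, 1])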
Code example #6
class _matmul(torch.autograd.Function):
    src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c"))

    _DEFAULT_CONFIGS = [
        triton.config(defines={
            "TM": "128",
            "TN": "128",
            "TK": "32",
            "SPLITK": "1"
        },
                      num_warps=4),
        triton.config(defines={
            'TM': '64',
            'TN': '128',
            'TK': '32',
            'SPLITK': '1'
        },
                      num_warps=4),
        triton.config(defines={
            'TM': '128',
            'TN': '64',
            'TK': '32',
            'SPLITK': '1'
        },
                      num_warps=4),
        triton.config(defines={
            'TM': '64',
            'TN': '64',
            'TK': '64',
            'SPLITK': '1'
        },
                      num_warps=4),
        triton.config(defines={
            'TM': '32',
            'TN': '128',
            'TK': '64',
            'SPLITK': '1'
        },
                      num_warps=4),
        triton.config(defines={
            'TM': '128',
            'TN': '32',
            'TK': '64',
            'SPLITK': '1'
        },
                      num_warps=4),
        triton.config(defines={
            'TM': '64',
            'TN': '32',
            'TK': '64',
            'SPLITK': '1'
        },
                      num_warps=2),
        triton.config(defines={
            'TM': '32',
            'TN': '64',
            'TK': '64',
            'SPLITK': '1'
        },
                      num_warps=2),
        triton.config(defines={
            'TM': '32',
            'TN': '128',
            'TK': '32',
            'SPLITK': '2'
        },
                      num_warps=4),
        triton.config(defines={
            'TM': '32',
            'TN': '128',
            'TK': '32',
            'SPLITK': '2'
        },
                      num_warps=4),
        triton.config(defines={
            'TM': '128',
            'TN': '32',
            'TK': '32',
            'SPLITK': '4'
        },
                      num_warps=4),
        triton.config(defines={
            'TM': '128',
            'TN': '32',
            'TK': '32',
            'SPLITK': '4'
        },
                      num_warps=4),
    ]
    _CONFIGS = _DEFAULT_CONFIGS

    @staticmethod
    def largest_pow2_divisor(N):
        if N % 8 == 0:
            return 8
        if N % 4 == 0:
            return 4
        if N % 2 == 0:
            return 2
        return 1

    _locks = dict()
    _kernels = dict()

    @staticmethod
    def _call(a, b):
        dtype = a.dtype
        device = a.device
        # allocate output
        M, K = a.shape
        K, N = b.shape
        c = torch.empty((M, N), dtype=dtype, device=device)
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # kernel hash
        is_a_row = a.stride(1) == 1
        is_b_row = b.stride(1) == 1
        lda = a.stride(0) if is_a_row else a.stride(1)
        ldb = b.stride(0) if is_b_row else b.stride(1)
        ldc = c.stride(0)
        lda_pow2_div = _matmul.largest_pow2_divisor(lda)
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        is_tk_div_k = K % 64 == 0
        key = (
            device,
            dtype,
            is_a_row,
            is_b_row,
            lda_pow2_div,
            ldb_pow2_div,
            ldc_pow2_div,
            is_tk_div_k,
        )
        if key not in _matmul._kernels:
            defines = {
                "TYPE": dtype,
                "STRIDE_AM": "lda" if is_a_row else "1",
                "STRIDE_AK": "1" if is_a_row else "lda",
                "STRIDE_BK": "ldb" if is_b_row else "1",
                "STRIDE_BN": "1" if is_b_row else "ldb",
                "LDA_POW2_DIV": lda_pow2_div,
                "LDB_POW2_DIV": ldb_pow2_div,
                "LDC_POW2_DIV": ldc_pow2_div,
                "IS_TK_DIV_K": int(is_tk_div_k),
            }
            _matmul._kernels[key] = triton.kernel(
                _matmul.src,
                device,
                defines=defines,
                autotune_vals=_matmul._CONFIGS,
                autotune_key=["M", "N", "K"],
            )
        kernel = _matmul._kernels[key]
        # locks for split-k
        if device not in _matmul._locks:
            _matmul._locks[device] = torch.zeros(1024 * 1024,
                                                 dtype=torch.int32,
                                                 device=device)
        locks = _matmul._locks[device]
        # enqueue
        alpha = 1.0
        args = [
            a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(),
            alpha,
            M,
            N,
            K,
            lda,
            ldb,
            ldc,
            locks.data_ptr(),
        ]
        grid = lambda opt: [
            triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN),
            1,
            opt.SPLITK,
        ]
        kernel(*args, grid=grid)
        return c

    @staticmethod
    def forward(ctx, a, b):
        c = _matmul._call(a, b)
        return c
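
A hypothetical call of the autotuned _matmul above; autotune_key=["M", "N", "K"] means a configuration is chosen and cached per problem shape:

# hypothetical usage sketch (shapes are illustrative)
a = torch.randn(512, 384, device='cuda', dtype=torch.float16)
b = torch.randn(384, 256, device='cuda', dtype=torch.float16)
c = _matmul.apply(a, b)   # (512, 256); the first call autotunes over _DEFAULT_CONFIGS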
Code example #7
File: matmul.py Project: daadaada/triton
class _matmul(torch.autograd.Function):
    src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c'))

    TM = [128]
    TN = [128]
    TK = [32]
    TZ = 1
    num_warps = [4]

    @staticmethod
    def largest_pow2_divisor(N):
        if N % 8 == 0: return 8
        if N % 4 == 0: return 4
        if N % 2 == 0: return 2
        return 1

    _locks = dict()
    _kernels = dict()

    @staticmethod
    def _call(a, b):
        dtype = a.dtype
        device = a.device
        # allocate output
        M, K = a.shape
        K, N = b.shape
        c = torch.empty((M, N), dtype=dtype, device=device)
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous()
        # kernel hash
        is_a_row = a.stride(1) == 1
        is_b_row = b.stride(1) == 1
        lda = a.stride(0) if is_a_row else a.stride(1)
        ldb = b.stride(0) if is_b_row else b.stride(1)
        ldc = c.stride(0)
        lda_pow2_div = _matmul.largest_pow2_divisor(lda)
        ldb_pow2_div = _matmul.largest_pow2_divisor(ldb)
        ldc_pow2_div = _matmul.largest_pow2_divisor(ldc)
        is_tk_div_k = K % 32 == 0
        key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div,
               ldc_pow2_div, is_tk_div_k)
        if key not in _matmul._kernels:
            defines = {
                'TYPE': dtype,
                'STRIDE_AM': 'lda' if is_a_row else '1',
                'STRIDE_AK': '1' if is_a_row else 'lda',
                'STRIDE_BK': 'ldb' if is_b_row else '1',
                'STRIDE_BN': '1' if is_b_row else 'ldb',
                'LDA_POW2_DIV': lda_pow2_div,
                'LDB_POW2_DIV': ldb_pow2_div,
                'LDC_POW2_DIV': ldc_pow2_div,
                'TM': _matmul.TM,
                'TN': _matmul.TN,
                'TK': _matmul.TK,
                'TZ': _matmul.TZ,
                'IS_TK_DIV_K': is_tk_div_k
            }
            _matmul._kernels[key] = triton.kernel(_matmul.src,
                                                  device,
                                                  num_warps=_matmul.num_warps,
                                                  defines=defines)
        kernel = _matmul._kernels[key]
        # locks for split-k
        if device not in _matmul._locks:
            _matmul._locks[device] = torch.zeros(1024 * 1024,
                                                 dtype=torch.int32,
                                                 device=device)
        locks = _matmul._locks[device]
        # enqueue
        alpha = 1.
        args = [
            a.data_ptr(),
            b.data_ptr(),
            c.data_ptr(), alpha, M, N, K, lda, ldb, ldc,
            locks.data_ptr()
        ]
        grid = lambda opt: [
            triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1
        ]
        kernel(*args, grid=grid)
        return c

    @staticmethod
    def forward(ctx, a, b):
        c = _matmul._call(a, b)
        return c
Code example #8
    class _softmax_xent_loss_in_place(torch.autograd.Function):
        """This modifies logits in place, turning them into negative logprobs
        on the forward pass.  It should not copy the logits at all.
        """

        fwd_src = triton.read(
            os.path.join(os.path.dirname(__file__),
                         "softmax_xent_kernels_in_place.c"),
            kernel_names=["softmax_fwd"],
        )
        bwd_src = triton.read(
            os.path.join(os.path.dirname(__file__),
                         "softmax_xent_kernels_in_place.c"),
            kernel_names=["softmax_bwd"],
        )

        # Need TILE = n_vocab for this approach to work:
        input_config_to_kernel_fwd = {}
        input_config_to_kernel_bwd = {}

        @classmethod
        def forward(cls, ctx, logits, indices):
            n_vocab = logits.shape[-1]
            assert indices.dtype == torch.int64
            assert n_vocab % 128 == 0, \
                "Number of logit options must be divisible by 128."

            if not (logits.dtype, n_vocab) in cls.input_config_to_kernel_fwd:
                cls.input_config_to_kernel_fwd[(logits.dtype,
                                                n_vocab)] = triton.kernel(
                                                    cls.fwd_src,
                                                    device=logits.device,
                                                    defines={
                                                        "TILE": n_vocab,
                                                        "TYPE": logits.dtype,
                                                    },
                                                    num_warps=[4],
                                                )
            kernel_fwd = cls.input_config_to_kernel_fwd[(logits.dtype,
                                                         n_vocab)]

            result = torch.zeros_like(indices, dtype=logits.dtype).cuda()
            grid = lambda opt: (logits.shape[0], )
            kernel_fwd(logits.data_ptr(),
                       indices.data_ptr(),
                       result.data_ptr(),
                       grid=grid)
            # logits -> neg_logprobs via an in place modification by kernel_fwd
            ctx.save_for_backward(logits, indices)

            return result

        @classmethod
        def backward(cls, ctx, dneg_logprobs):
            """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
            so we initialize the gradient as neg_logprobs, so we can just exponentiate
            to get p[k], which is most of what we need...  neg_logprobs will be
            modified in place to become the gradient we want
            """
            neg_logprobs, indices = ctx.saved_tensors
            assert indices.dtype == torch.int64
            assert (
                dneg_logprobs.dtype == neg_logprobs.dtype
            ), f"Backward flowing derivatives of type {dneg_logprobs.dtype} != logits type {neg_logprobs.dtype}"
            n_vocab = neg_logprobs.shape[-1]
            N = neg_logprobs.numel()

            if not (neg_logprobs.dtype,
                    n_vocab) in cls.input_config_to_kernel_bwd:
                cls.input_config_to_kernel_bwd[(neg_logprobs.dtype,
                                                n_vocab)] = triton.kernel(
                                                    cls.bwd_src,
                                                    device=neg_logprobs.device,
                                                    defines={
                                                        "TILE": n_vocab,
                                                        "TYPE":
                                                        neg_logprobs.dtype
                                                    },
                                                    num_warps=[4],
                                                )
            kernel_bwd = cls.input_config_to_kernel_bwd[(neg_logprobs.dtype,
                                                         n_vocab)]
            grid = lambda opt: (triton.cdiv(N, opt.TILE), )

            # neg_logprobs will be modified in place to become our gradient:
            kernel_bwd(
                neg_logprobs.data_ptr(),
                indices.data_ptr(),
                dneg_logprobs.data_ptr(),
                grid=grid,
            )

            return neg_logprobs, torch.zeros_like(indices)
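
A hypothetical call of the in-place variant; here n_vocab must be a multiple of 128, and the forward kernel overwrites the logits with negative logprobs:

# hypothetical usage sketch: the forward kernel overwrites logits in place
logits = torch.randn(32, 512, device='cuda', dtype=torch.float32)
targets = torch.randint(0, 512, (32,), device='cuda')        # 512 is divisible by 128
loss = _softmax_xent_loss_in_place.apply(logits, targets)    # logits now hold -log p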
Code example #9
import triton
import torch
import os

fwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
                      kernel_names=['forward'])
fwd_kernels = dict()

bwd_src = triton.read(os.path.join(os.path.dirname(__file__), 'softmax.c'),
                      kernel_names=['backward'])
bwd_kernels = dict()


class _softmax(torch.autograd.Function):
    @staticmethod
    def next_power_of_2(n):
        n -= 1
        n |= n >> 1
        n |= n >> 2
        n |= n >> 4
        n |= n >> 8
        n |= n >> 16
        n += 1
        return n

    @staticmethod
    def make_lut(layout, block, device):
        _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
        sizes = _empty.clone()
        # sizes along rows
        for h in range(layout.shape[0]):