Example #1
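# The kernels in these examples use the Triton language API. A minimal sketch
# of the imports they assume (each kernel body would additionally need the
# @triton.jit decorator before it can be launched):
import torch
import triton
import triton.language as tl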
def _backward(PROBS, IDX, DPROBS, N, **meta):
    BLOCK = meta['BLOCK']
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to probs
    PROBS = PROBS + row * N + cols
    # We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k]
    # and we have -log(p[k]) stored in PROBS, so this is easy
    probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
    probs = tl.exp(probs.to(tl.float32))
    delta = cols == idx
    # write result in-place in PROBS
    dout = tl.load(DPROBS + row)
    din = (probs - delta) * dout
    tl.store(PROBS, din.to(tl.float16), mask=cols < N)
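# A minimal host-side launch sketch for the backward kernel above (illustrative
# only; assumes @triton.jit decoration and the imports listed under Example #1).
# `neg_log_probs` holds the values written by the forward pass and is
# overwritten in-place with the input gradients.
def cross_entropy_backward(neg_log_probs, idx, dlosses):
    n_rows, n_cols = neg_log_probs.shape
    BLOCK = triton.next_power_of_2(n_cols)
    _backward[(n_rows,)](neg_log_probs, idx, dlosses, n_cols, BLOCK=BLOCK)
    return neg_log_probs  # now contains dlogits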
Example #2
def _softmax(
        Y,  # output pointer
        X,  # input pointer
        stride_xm,  # stride between consecutive rows of X
        stride_ym,  # stride between consecutive rows of Y
        M,  # number of rows
        N,  # number of columns
        **meta):

    ## Note: the stride is the number of elements you have to skip
    ## in memory to get to the next element along a given dimension.
    ## Ex: if x is 2D, let's say x = torch.randn(10, 20), then
    ## x[1, 1].data_ptr() == x[1, 0].data_ptr() + x.stride(1) * x.element_size()

    # 1 program id per row index
    m = tl.program_id(0)

    # col indices
    # here BLOCK is the smallest power of two greater than or equal to `N`
    # In Triton, block sizes always have to be powers of 2;
    # powers of 2 generally fit / align nicely with cache lines

    n = tl.arange(0, meta['BLOCK'])

    # the memory addresses of all the elements
    # that we want to load can be computed as follows:
    #   X + m * stride_xm points to the start of row m, and adding n
    #   gives one pointer per column of that row

    X = X + m * stride_xm + n

    # x is the vector of values loaded.
    # Out-of-bounds lanes (n >= N) are filled with -inf so that they
    # contribute nothing to the max / exp-sum computed below
    x = tl.load(X, mask=n < N, other=-float('inf'))

    # Subtract maximum for numerical stability
    z = x - tl.max(x, axis=0)
    # Note that exponentials in Triton are fast
    # but approximate (i.e., think __expf in CUDA)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)
    y = num / denom
    # Write back to Y
    Y = Y + m * stride_ym + n
    tl.store(Y, y, mask=n < N)
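# A minimal host-side wrapper sketch for the kernel above (illustrative;
# assumes @triton.jit decoration and the imports from Example #1). BLOCK must
# be a power of 2 that covers a full row, hence triton.next_power_of_2.
def softmax(x):
    M, N = x.shape
    y = torch.empty_like(x)
    _softmax[(M,)](y, x, x.stride(0), y.stride(0), M, N,
                   BLOCK=triton.next_power_of_2(N))
    return y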
Example #3
def _add(
    X,  # *Pointer* to first input vector
    Y,  # *Pointer* to second input vector
    Z,  # *Pointer* to output vector
    N,  # Size of the vector
    **meta  # Optional meta-parameters for the kernel
):
    pid = tl.program_id(0)
    # Create an offset for the blocks of pointers to be
    # processed by this program instance
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    # Create a mask to guard memory operations against
    # out-of-bounds accesses
    mask = offsets < N
    # Load x
    x = tl.load(X + offsets, mask=mask)
    y = tl.load(Y + offsets, mask=mask)
    # Write back x + y
    z = x + y
    tl.store(Z + offsets, z, mask=mask)
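# A minimal host-side launch sketch (illustrative; assumes @triton.jit).
# The grid is a callable so that it can depend on the BLOCK meta-parameter.
def add(x, y):
    z = torch.empty_like(x)
    N = z.numel()
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
    _add[grid](x, y, z, N, BLOCK=1024)
    return z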
Example #4
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
    BLOCK = meta['BLOCK']
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to logit and probs
    LOGITS = LOGITS + row * N + cols
    WRIT_PROBS = PROBS + row * N + cols
    READ_PROBS = PROBS + row * N + idx
    # write-back negative log-probs
    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
    logits = logits.to(tl.float32)
    logits = logits - tl.max(logits, 0)
    probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
    tl.store(WRIT_PROBS, probs, mask=cols < N)
    # There is a bug in the compiler, which fails to insert a barrier here.
    # We add it explicitly for now. Will be fixed soon.
    tl.debug_barrier()
    # write-back loss
    probs = tl.load(READ_PROBS)
    tl.store(LOSS + row, probs)
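# A minimal host-side launch sketch for the forward kernel above (illustrative;
# assumes @triton.jit). The PROBS buffer receives -log_softmax(logits) row by
# row, and LOSS[row] = -log_softmax(logits)[row, idx[row]], which can be checked
# against torch.nn.functional.cross_entropy with reduction='none'.
def cross_entropy_forward(logits, idx):
    n_rows, n_cols = logits.shape
    neg_log_probs = torch.empty_like(logits)
    loss = torch.empty(n_rows, dtype=logits.dtype, device=logits.device)
    _forward[(n_rows,)](logits, neg_log_probs, idx, loss, n_cols,
                        BLOCK=triton.next_power_of_2(n_cols))
    return loss, neg_log_probs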
Example #5
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
            stride_cm, stride_cn, **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = 8
    # matrix multiplication
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # Triton can accept an arbitrary activation function
    # via meta-parameters!
    if META['ACTIVATION']:
        acc = META['ACTIVATION'](acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(C, acc, mask=mask)
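# A minimal host-side launch sketch (illustrative block sizes; assumes
# @triton.jit). The 1D grid covers all (BLOCK_M x BLOCK_N) tiles of C; the
# kernel itself re-orders program IDs into groups of GROUP_M for L2 reuse.
# The inner loop loads without masks, so K should be a multiple of BLOCK_K.
def matmul(a, b, activation=None):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) *
                         triton.cdiv(N, META['BLOCK_N']),)
    _matmul[grid](a, b, c, M, N, K,
                  a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                  c.stride(0), c.stride(1),
                  BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, ACTIVATION=activation)
    return c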
Example #6
def _add(
    X,  # *pointer* to first input vector
    Y,  # *pointer* to second input vector
    Z,  # *pointer* to output vector
    N,  # size of the vector
    **meta  # optional meta-parameters for the kernel
    # (meta carries the block size and other compile-time constants)
):
    # Program ID: identifies which block of the input this instance handles
    pid = tl.program_id(0)

    # We don't worry about individual threads here: a Triton program
    # is written for a whole block at a time

    # Create the offsets for the block of elements to be processed
    # by this program instance
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])

    # Create a mask to guard memory operations against
    # out-of-bounds accesses: it defines which lanes are loaded
    # and which are left untouched
    mask = offsets < N

    # Load x and y
    # base pointer + offsets -> yields a block of pointers
    load_x_ptrs = X + offsets
    load_y_ptrs = Y + offsets

    x = tl.load(load_x_ptrs, mask=mask)  # x <- tensor of data
    y = tl.load(load_y_ptrs, mask=mask)

    #Write x + y
    z = x + y
    tl.store(Z + offsets, z, mask=mask)
Example #7
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride,
                   n_cols, **meta):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
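# A minimal usage / correctness-check sketch (illustrative; assumes
# @triton.jit). One program per row, with BLOCK_SIZE padded up to the next
# power of 2 so a single program covers an entire row.
x = torch.randn(128, 781, device='cuda')
y = torch.empty_like(x)
softmax_kernel[(x.shape[0],)](y, x, x.stride(0), y.stride(0), x.shape[1],
                              BLOCK_SIZE=triton.next_power_of_2(x.shape[1]))
assert torch.allclose(y, torch.softmax(x, dim=1))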
Example #8
def _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb,
            stride_hb, stride_kb, stride_nb, stride_zc, stride_hc, stride_mc,
            stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks,
            **meta):
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
    TZ = meta['TZ']
    BLOCK = meta['BLOCK']
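    # Rough orientation (as in triton.ops.blocksparse.matmul): the meta-flags
    # select the sparsity mode -- SDD: sparse C = dense A @ dense B,
    # DSD: dense C = sparse A @ dense B, DDS: dense C = dense A @ sparse B.
    # `lut` is a precomputed look-up table describing which BLOCK x BLOCK
    # tiles are present, and `locks` backs the spin-lock accumulation below.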
    #------------#
    #- Prologue -#
    #------------#
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pidz = tl.program_id(2)
    if meta['SDD']:
        pid1 = pid1 + SDD_off_width
        blockidm = tl.arange(0, TM) // BLOCK
        blockidn = tl.arange(0, TN) // BLOCK
        offlutm = blockidm * (TN // BLOCK) * 4
        offlutn = blockidn * 4
        header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
        z = tl.load(header + 0)
        i = tl.load(header + 1 + offlutm)
        j = tl.load(header + 2 + offlutn)
        AS1 = SDD_K // TZ
        lockid = tl.where(TZ > 1, 1, 0)
        offka = pid0 * AS1
        offkb = pid0 * AS1
        offmc = 0
        offnc = 0
        offpa = 0
        offpb = 0
        maxid = TZ
        offhc = 0
        offha = z
        offhb = z
        ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
        rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
    else:
        header = lut + pid0 * 6
        offset = tl.load(header + 0)
        AS1 = tl.load(header + 1)
        column = tl.load(header + 2)
        depth = tl.load(header + 3)
        lockid = tl.load(header + 4)
        maxid = tl.load(header + 5)
        pinc = lut + offset
        offhc = depth
        if meta['DSD']:
            # output offset
            offnc = pid1 * TN
            offmc = column * TM
            offpc = 0
            # dense input offset
            offnb = pid1 * TN
            offkb = tl.load(pinc)
            offkb = tl.multiple_of(offkb, 8)  # compiler hint
            offpb = 0
            # sparse input offset
            offma = 0
            offka = 0
            offpa = tl.load(pinc + 1)
            offpa = tl.multiple_of(offpa, 8)  # compiler hint
            offpa = offpa * BLOCK * BLOCK
            offha = 0
            offhb = depth
        else:
            # output offset
            offmc = pid1 * TM
            offnc = column * TN
            offpc = 0
            # dense input offset
            offma = pid1 * TM
            offka = tl.load(pinc)
            offka = tl.multiple_of(offka, 8)  # compiler hint
            offpa = 0
            # sparse input offset
            offnb = 0
            offkb = 0
            offpb = tl.load(pinc + 1)
            offpb = tl.multiple_of(offpb, 8)  # compiler hint
            offpb = offpb * BLOCK * BLOCK
            offha = depth
            offhb = 0
        ram = offma + tl.arange(0, TM)
        rbn = offnb + tl.arange(0, TN)

    # initialize a, b pointers
    rka = offka + tl.arange(0, TK)
    rkb = offkb + tl.arange(0, TK)
    pa = (A + pidz * stride_za + offha * stride_ha + offpa
          + ram[:, None] * stride_ma + rka[None, :] * stride_ka)
    pb = (B + pidz * stride_zb + offhb * stride_hb + offpb
          + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb)
    if meta['DDS']:
        checkam = ram[:, None] < DS0
    else:
        checkam = AS1 > 0
    if meta['DSD']:
        checkbn = rbn[None, :] < DS0
    else:
        checkbn = AS1 > 0
    a = tl.load(pa, mask=checkam, other=0.)
    b = tl.load(pb, mask=checkbn, other=0.)

    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TM, TN), dtype=tl.float32)
    for k in range(AS1, 0, -TK):
        acc += tl.dot(a, b)
        if meta['SDD']:
            inc_a = TK * stride_ka
            inc_b = TK * stride_kb
        else:
            pinc += 2
        if meta['DSD']:
            inc_b = tl.load(pinc)
            inc_a = tl.load(pinc + 1)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = inc_b * stride_kb
        if meta['DDS']:
            inc_a = tl.load(pinc)
            inc_b = tl.load(pinc + 1)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = inc_a * stride_ka
        pa += inc_a
        pb += inc_b
        # pre-fetch
        checkak = k > TK
        checkbk = k > TK
        checka = checkam & checkak
        checkb = checkbn & checkbk
        a = tl.load(pa, mask=checka)
        b = tl.load(pb, mask=checkb)
    c = acc.to(C.dtype.element_ty)

    if meta['SDD']:
        checkc = True
        rr_blockidm = tl.arange(0, TM) // BLOCK
        rr_blockidn = tl.arange(0, TN) // BLOCK
        rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
        rr_offlutn = rr_blockidn * 4
        off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
        bkid = tl.load(header + off_bkid)
        offpc = bkid * BLOCK * BLOCK
        rcm = tl.arange(0, TM) % BLOCK
        rcn = tl.arange(0, TN) % BLOCK
    else:
        rcm = offmc + tl.arange(0, TM)
        rcn = offnc + tl.arange(0, TN)
    if meta['DSD']:
        checkc = rcn[None, :] < DS0
    if meta['DDS']:
        checkc = rcm[:, None] < DS0

    pc = (C + offpc + offhc * stride_hc + pidz * stride_zc
          + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc)
    # write-back directly
    if lockid == 0:
        tl.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        plock = (locks + tl.program_id(2) * nlocks * tl.num_programs(1)
                 + tl.program_id(1) * nlocks + lockid - 1)
        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
        while tl.atomic_cas(plock, 0, 1) == 1:
            pass
        count = tl.load(pcount)
        if count == 0:
            tl.store(pc, c, mask=checkc)
        else:
            d = tl.load(pc, mask=checkc)
            tl.store(pc, d + c, mask=checkc)
        tl.atomic_xchg(pcount, (count + 1) % maxid)
        tl.atomic_xchg(plock, 0)
Example #9
def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
            stride_cm, stride_cn, LOCKS, **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = META['GROUP_M']
    SPLIT_K = META['SPLIT_K']
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    K = K // SPLIT_K
    A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am +
             rk[None, :] * stride_ak)
    B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk +
             rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        if META['EVEN_K']:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    acc = acc.to(tl.float16)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        LOCKS = LOCKS + tl.program_id(0)
        COUNT = LOCKS + tl.num_programs(0)
        while tl.atomic_cas(LOCKS, 0, 1) == 1:
            pass
        count = tl.load(COUNT)
        if count == 0:
            tl.store(C, acc, mask=mask)
        else:
            curr = tl.load(C, mask=mask, other=0.)
            tl.store(C, acc + curr, mask=mask)
        tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
        tl.atomic_xchg(LOCKS, 0)
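# A minimal host-side launch sketch (illustrative values; assumes @triton.jit
# and that K is divisible by SPLIT_K). Judging from the kernel's indexing,
# LOCKS must hold one lock plus one counter per program on axis 0, i.e. at
# least 2 * grid_m * grid_n zero-initialized int32 entries.
def matmul_splitk(a, b, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, GROUP_M=8, SPLIT_K=4):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid_m = triton.cdiv(M, BLOCK_M)
    grid_n = triton.cdiv(N, BLOCK_N)
    locks = torch.zeros(2 * grid_m * grid_n, dtype=torch.int32, device=a.device)
    _kernel[(grid_m * grid_n, SPLIT_K)](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
        c.stride(0), c.stride(1), locks,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        GROUP_M=GROUP_M, SPLIT_K=SPLIT_K,
        EVEN_K=(K % (BLOCK_K * SPLIT_K) == 0))
    return c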
Example #10
def kernel(X, Z, **meta):
    pid = tl.program_id(0)
    old = tl.atomic_add(X, pid)
    tl.store(Z + pid, old)
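# A minimal usage sketch (illustrative; assumes @triton.jit). Each program
# atomically adds its program ID to X[0] and records the value X held before
# its own addition, so Z ends up holding partial sums in serialization order
# and X[0] accumulates the total.
n = 64
x = torch.zeros(1, dtype=torch.int32, device='cuda')
z = torch.empty(n, dtype=torch.int32, device='cuda')
kernel[(n,)](x, z)
assert x.item() == n * (n - 1) // 2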
Example #11
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    **meta,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # extract meta-parameters
    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
    BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
    GROUP_SIZE_M = 8

    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse
    # See above `L2 Cache Optimizations` section for details
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n 
    group_id = pid // num_pid_in_group 
    first_pid_m = group_id * GROUP_SIZE_M 
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction 
    # and accumulate
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # see above `Pointer Arithmetics` section for details
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Note that for simplicity, we don't apply a mask here. 
        # This means that if K is not a multiple of BLOCK_SIZE_K, 
        # this will access out-of-bounds memory and produce an
        # error or (worse!) incorrect results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # We accumulate along the K dimension
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # you can fuse arbitrary activation functions here
    # while the accumulator is still in FP32 !
    if meta['ACTIVATION']: 
        accumulator = meta['ACTIVATION'](accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
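# A minimal host-side launch sketch (illustrative block sizes; assumes
# @triton.jit). As noted in the kernel's comments, the inner loop loads
# without masks, so K should be a multiple of BLOCK_SIZE_K.
def matmul(a, b, activation=None):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) *
                         triton.cdiv(N, meta['BLOCK_SIZE_N']),)
    matmul_kernel[grid](a, b, c, M, N, K,
                        a.stride(0), a.stride(1),
                        b.stride(0), b.stride(1),
                        c.stride(0), c.stride(1),
                        BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32,
                        ACTIVATION=activation)
    return c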
Example #12
def kernel(X, Z, **meta):
    pid = tl.program_id(0)
    x = tl.load(X + pid)
    old = GENERATE_TEST_HERE
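# Note: GENERATE_TEST_HERE is a textual placeholder (this snippet appears to
# come from a test template); a concrete expression operating on `x`, such as
# an atomic or math op, would be substituted in before the kernel is compiled.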