Example #1
import triton
import triton.language as tl


@triton.jit
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
            stride_cm, stride_cn, **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = 8
    # matrix multiplication
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # Triton can accept an arbitrary activation function
    # via meta-parameters!
    if META['ACTIVATION']:
        acc = META['ACTIVATION'](acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(C, acc, mask=mask)
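
A minimal host-side launch for this kernel could look like the sketch below. It is not part of the original listing and assumes the Triton 1.x-era API, in which meta-parameters such as BLOCK_M, BLOCK_N, BLOCK_K and ACTIVATION are passed as keyword arguments at launch time; the tensor shapes, block sizes, and the fp32 output buffer are illustrative choices.

import torch

# illustrative shapes: fp16 inputs feed tl.dot, fp32 output matches the accumulator
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c = torch.empty((512, 512), device='cuda', dtype=torch.float32)
M, K = a.shape
_, N = b.shape

# one program instance per (BLOCK_M, BLOCK_N) tile of C
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
_matmul[grid](a, b, c, M, N, K,
              a.stride(0), a.stride(1),
              b.stride(0), b.stride(1),
              c.stride(0), c.stride(1),
              BLOCK_M=128, BLOCK_N=128, BLOCK_K=32,
              ACTIVATION=None)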
Example #2
import triton
import triton.language as tl


@triton.jit
def _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb,
            stride_hb, stride_kb, stride_nb, stride_zc, stride_hc, stride_mc,
            stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks,
            **meta):
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
    TZ = meta['TZ']
    BLOCK = meta['BLOCK']
    #------------#
    #- Prologue -#
    #------------#
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pidz = tl.program_id(2)
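    # Modes, selected by meta-flags (naming follows C = A x B):
    #   SDD: C is block-sparse, A and B are dense
    #   DSD: A is block-sparse, B and C are dense
    #   DDS: B is block-sparse, A and C are dense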
    if meta['SDD']:
        pid1 = pid1 + SDD_off_width
        blockidm = tl.arange(0, TM) // BLOCK
        blockidn = tl.arange(0, TN) // BLOCK
        offlutm = blockidm * (TN // BLOCK) * 4
        offlutn = blockidn * 4
        header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
        z = tl.load(header + 0)
        i = tl.load(header + 1 + offlutm)
        j = tl.load(header + 2 + offlutn)
        AS1 = SDD_K // TZ
        lockid = tl.where(TZ > 1, 1, 0)
        offka = pid0 * AS1
        offkb = pid0 * AS1
        offmc = 0
        offnc = 0
        offpa = 0
        offpb = 0
        maxid = TZ
        offhc = 0
        offha = z
        offhb = z
        ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
        rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
    else:
        header = lut + pid0 * 6
        offset = tl.load(header + 0)
        AS1 = tl.load(header + 1)
        column = tl.load(header + 2)
        depth = tl.load(header + 3)
        lockid = tl.load(header + 4)
        maxid = tl.load(header + 5)
        pinc = lut + offset
        offhc = depth
        if meta['DSD']:
            # output offset
            offnc = pid1 * TN
            offmc = column * TM
            offpc = 0
            # dense input offset
            offnb = pid1 * TN
            offkb = tl.load(pinc)
            offkb = tl.multiple_of(offkb, 8)  # compiler hint
            offpb = 0
            # sparse input offset
            offma = 0
            offka = 0
            offpa = tl.load(pinc + 1)
            offpa = tl.multiple_of(offpa, 8)  # compiler hint
            offpa = offpa * BLOCK * BLOCK
            offha = 0
            offhb = depth
        else:
            # output offset
            offmc = pid1 * TM
            offnc = column * TN
            offpc = 0
            # dense input offset
            offma = pid1 * TM
            offka = tl.load(pinc)
            offka = tl.multiple_of(offka, 8)  # compiler hint
            offpa = 0
            # sparse input offset
            offnb = 0
            offkb = 0
            offpb = tl.load(pinc + 1)
            offpb = tl.multiple_of(offpb, 8)  # compiler hint
            offpb = offpb * BLOCK * BLOCK
            offha = depth
            offhb = 0
        ram = offma + tl.arange(0, TM)
        rbn = offnb + tl.arange(0, TN)

    # initialize a, b pointers
    rka = offka + tl.arange(0, TK)
    rkb = offkb + tl.arange(0, TK)
    pa = (A + pidz * stride_za + offha * stride_ha + offpa
          + ram[:, None] * stride_ma + rka[None, :] * stride_ka)
    pb = (B + pidz * stride_zb + offhb * stride_hb + offpb
          + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb)
    if meta['DDS']:
        checkam = ram[:, None] < DS0
    else:
        checkam = AS1 > 0
    if meta['DSD']:
        checkbn = rbn[None, :] < DS0
    else:
        checkbn = AS1 > 0
    a = tl.load(pa, mask=checkam, other=0.)
    b = tl.load(pb, mask=checkbn, other=0.)

    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TM, TN), dtype=tl.float32)
    for k in range(AS1, 0, -TK):
        acc += tl.dot(a, b)
        if meta['SDD']:
            inc_a = TK * stride_ka
            inc_b = TK * stride_kb
        else:
            pinc += 2
        if meta['DSD']:
            inc_b = tl.load(pinc)
            inc_a = tl.load(pinc + 1)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = inc_b * stride_kb
        if meta['DDS']:
            inc_a = tl.load(pinc)
            inc_b = tl.load(pinc + 1)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = inc_a * stride_ka
        pa += inc_a
        pb += inc_b
        # pre-fetch
        checkak = k > TK
        checkbk = k > TK
        checka = checkam & checkak
        checkb = checkbn & checkbk
        a = tl.load(pa, mask=checka)
        b = tl.load(pb, mask=checkb)
    c = acc.to(C.dtype.element_ty)

    if meta['SDD']:
        checkc = True
        rr_blockidm = tl.arange(0, TM) // BLOCK
        rr_blockidn = tl.arange(0, TN) // BLOCK
        rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
        rr_offlutn = rr_blockidn * 4
        off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
        bkid = tl.load(header + off_bkid)
        offpc = bkid * BLOCK * BLOCK
        rcm = tl.arange(0, TM) % BLOCK
        rcn = tl.arange(0, TN) % BLOCK
    else:
        rcm = offmc + tl.arange(0, TM)
        rcn = offnc + tl.arange(0, TN)
    if meta['DSD']:
        checkc = rcn[None, :] < DS0
    if meta['DDS']:
        checkc = rcm[:, None] < DS0

    pc = (C + offpc + offhc * stride_hc + pidz * stride_zc
          + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc)
    # write-back directly
    if lockid == 0:
        tl.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        plock = (locks + tl.program_id(2) * nlocks * tl.num_programs(1)
                 + tl.program_id(1) * nlocks + lockid - 1)
        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
        while tl.atomic_cas(plock, 0, 1) == 1:
            pass
        count = tl.load(pcount)
        if count == 0:
            tl.store(pc, c, mask=checkc)
        else:
            d = tl.load(pc, mask=checkc)
            tl.store(pc, d + c, mask=checkc)
        tl.atomic_xchg(pcount, (count + 1) % maxid)
        tl.atomic_xchg(plock, 0)
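
This kernel is not launched directly: in the Triton 1.x sources it sits behind the triton.ops.blocksparse.matmul wrapper, which builds the block layout's look-up table (lut), the locks buffer and the launch grid on the host. The sketch below shows roughly how that wrapper is used; the constructor signature, module path and tensor shapes are assumptions from that era and may differ between releases.

import torch
from triton.ops.blocksparse import matmul as blocksparse_matmul

Z, H, M, N, K, BLOCK = 1, 2, 256, 256, 256, 16
# layout marks which BLOCK x BLOCK tiles of the sparse output C exist ('sdd' mode)
layout = torch.randint(0, 2, (H, M // BLOCK, N // BLOCK), dtype=torch.long)
dot = blocksparse_matmul(layout, BLOCK, 'sdd', trans_a=False, trans_b=False)

a = torch.randn((Z, H, M, K), device='cuda', dtype=torch.float16)
b = torch.randn((Z, H, K, N), device='cuda', dtype=torch.float16)
c = dot(a, b)  # non-zero blocks of C, shaped (Z, layout.sum(), BLOCK, BLOCK)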
Example #3
import triton
import triton.language as tl


@triton.jit
def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
            stride_cm, stride_cn, LOCKS, **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = META['GROUP_M']
    SPLIT_K = META['SPLIT_K']
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    K = K // SPLIT_K
    A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am +
             rk[None, :] * stride_ak)
    B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk +
             rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        if META['EVEN_K']:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    acc = acc.to(tl.float16)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        LOCKS = LOCKS + tl.program_id(0)
        COUNT = LOCKS + tl.num_programs(0)
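        # spin until this tile's lock is acquired; the first split-K program to
        # arrive initializes C, later ones accumulate, and the counter wraps
        # modulo SPLIT_K so the buffers can be reused without re-zeroing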
        while tl.atomic_cas(LOCKS, 0, 1) == 1:
            pass
        count = tl.load(COUNT)
        if count == 0:
            tl.store(C, acc, mask=mask)
        else:
            curr = tl.load(C, mask=mask, other=0.)
            tl.store(C, acc + curr, mask=mask)
        tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
        tl.atomic_xchg(LOCKS, 0)
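
A launch sketch for the split-K path, again assuming the Triton 1.x-era API; shapes, block sizes and SPLIT_K are illustrative. The grid is 2D: axis 0 enumerates output tiles and axis 1 enumerates the SPLIT_K slices that cooperate on each tile. The locks buffer holds one lock plus one counter per tile and must start zeroed.

import torch

a = torch.randn((1024, 1024), device='cuda', dtype=torch.float16)
b = torch.randn((1024, 1024), device='cuda', dtype=torch.float16)
c = torch.empty((1024, 1024), device='cuda', dtype=torch.float16)
M, K = a.shape
_, N = b.shape

BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, SPLIT_K = 128, 128, 32, 8, 4
num_tiles = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
locks = torch.zeros(2 * num_tiles, dtype=torch.int32, device='cuda')

grid = (num_tiles, SPLIT_K)
_kernel[grid](a, b, c, M, N, K,
              a.stride(0), a.stride(1),
              b.stride(0), b.stride(1),
              c.stride(0), c.stride(1),
              locks,
              BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
              GROUP_M=GROUP_M, SPLIT_K=SPLIT_K,
              EVEN_K=(K % (BLOCK_K * SPLIT_K) == 0))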
Example #4
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    **meta,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # extract meta-parameters
    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
    BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
    GROUP_SIZE_M = 8

    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse
    # See above `L2 Cache Optimizations` section for details
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n 
    group_id = pid // num_pid_in_group 
    first_pid_m = group_id * GROUP_SIZE_M 
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
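    # For illustration: with GROUP_SIZE_M = 2 and a 4 x 3 grid of tiles,
    # pids 0..5 cover tile rows {0, 1} column by column (m0n0, m1n0, m0n1, ...)
    # before pids 6..11 move on to rows {2, 3}, so the same few row-blocks of A
    # and each column-block of B are reused by nearby programs.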

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction 
    # and accumulate
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # see above `Pointer Arithmetics` section for details
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Note that for simplicity, we don't apply a mask here. 
        # This means that if K is not a multiple of BLOCK_SIZE_K, 
        # this will access out-of-bounds memory and produce an
        # error or (worse!) incorrect results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # We accumulate along the K dimension
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if meta['ACTIVATION']: 
        accumulator = meta['ACTIVATION'](accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
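
A host-side wrapper in the style of the tutorial this kernel comes from; the helper name matmul, the block sizes and the fp16 output are illustrative, and meta-parameters are passed as launch-time keyword arguments (Triton 1.x-era API). Note that K should be a multiple of BLOCK_SIZE_K, since the inner loop loads without a K mask.

import torch

def matmul(a, b, activation=None):
    M, K = a.shape
    K, N = b.shape
    # allocate the output in fp16 to match the cast at the end of the kernel
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch grid: one program per (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32,
        ACTIVATION=activation,
    )
    return c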