def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
    """Compute one ``BLOCK_M x BLOCK_N`` tile of ``C = A @ B`` with an optional fused activation.

    Each program (identified by ``tl.program_id(0)``) accumulates its tile in
    fp32 over the K dimension in steps of ``BLOCK_K`` and stores the result
    with a mask on the M and N bounds.

    Args:
        A, B, C: base pointers of the two inputs and the output.
            Assumes A is (M, K), B is (K, N) and C is (M, N), as suggested by
            the stride names -- TODO confirm against the launcher.
        M, N, K: problem sizes used for the grid computation, the K loop
            bound and the masks.
        stride_am, stride_ak: element strides of A along M and K.
        stride_bk, stride_bn: element strides of B along K and N.
        stride_cm, stride_cn: element strides of C along M and N.
        **META: meta-parameters. Reads ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``
            (tile sizes) and ``ACTIVATION`` (a callable applied to the fp32
            accumulator, or a falsy value to skip it).

    Notes:
        * Loads are masked along K so that a K which is not a multiple of
          ``BLOCK_K`` does not read past the K extent; the masked-out
          elements are loaded as 0 and do not contribute to the dot product.
        * Loads are NOT masked along M/N: rows/columns beyond ``M``/``N`` are
          still read, only the final store is masked.
        * The accumulator is stored as-is (no explicit cast before the store).
    """
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = 8
    # matrix multiplication
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance:
    # programs are grouped GROUP_M tile-rows at a time; the last group may
    # hold fewer than GROUP_M rows, hence the `min`.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # `k` counts the K elements still to be processed (K, K - BLOCK_K, ...),
    # so `rk < k` is all-true except on a partial last block.
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A, mask=rk[None, :] < k, other=0.)
        b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # triton can accept arbitrary activation function
    # via metaparameters!
    if META['ACTIVATION']:
        acc = META['ACTIVATION'](acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    # only elements inside the (M, N) bounds are written
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(C, acc, mask=mask)
def _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc, stride_hc, stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta):
    """Block-sparse matmul kernel with three modes selected by ``meta`` flags.

    The mode is chosen by ``meta['SDD']``, ``meta['DSD']`` and ``meta['DDS']``
    (presumably "sparse = dense x dense", "dense = sparse x dense" and
    "dense = dense x sparse" -- TODO confirm against the launcher):

    * ``SDD``: the look-up table ``lut`` is read in records of 4 ints per
      ``BLOCK x BLOCK`` block (fields at offsets 0, 1, 2, 3 are used as the
      head index ``z``, the block row ``i``, the block column ``j`` and the
      output block id). The reduction length is ``SDD_K // TZ`` per program.
    * otherwise: ``lut`` starts with one 6-int header per program along grid
      axis 0 (``offset``, ``AS1``, ``column``, ``depth``, ``lockid``,
      ``maxid``); ``lut + offset`` points at a table of increment pairs that
      is walked two ints at a time in the inner loop. In the ``DSD`` branch
      A is the sparse operand and B the dense one; in the other branch the
      roles are swapped (per the in-code comments).

    Args:
        A, B, C: base pointers of the two inputs and the output.
        stride_z*, stride_h*: strides applied to ``tl.program_id(2)`` and to
            the head/depth offset, for A, B and C respectively.
        stride_ma, stride_ka / stride_kb, stride_nb / stride_mc, stride_nc:
            strides of the matrix dimensions of A, B and C.
        DS0: bound used to mask the dense dimension (rows of A in ``DDS``,
            columns of B in ``DSD``, and the corresponding output dimension).
        DS1: not read by this kernel.
        SDD_K: total reduction length in ``SDD`` mode.
        SDD_off_width: added to ``tl.program_id(1)`` in ``SDD`` mode.
        lut: pointer to the integer look-up table described above.
        locks: pointer to the lock/counter buffer used when ``lockid != 0``.
        nlocks: number of locks per (axis-2, axis-1) program pair, used to
            index ``locks``.
        **meta: reads ``TM``, ``TN``, ``TK`` (tile sizes), ``TZ`` (reduction
            split factor in ``SDD`` mode), ``BLOCK`` (sparse block size) and
            the mode flags ``SDD``, ``DSD``, ``DDS``.
    """
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
    TZ = meta['TZ']
    BLOCK = meta['BLOCK']
    #------------#
    #- Prologue -#
    #------------#
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pidz = tl.program_id(2)
    if meta['SDD']:
        pid1 = pid1 + SDD_off_width
        # index of the BLOCK-sized sub-block each tile row/column falls into
        blockidm = tl.arange(0, TM) // BLOCK
        blockidn = tl.arange(0, TN) // BLOCK
        # each sub-block owns a 4-int record in the look-up table
        offlutm = blockidm * (TN // BLOCK) * 4
        offlutn = blockidn * 4
        header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
        z = tl.load(header + 0)
        i = tl.load(header + 1 + offlutm)
        j = tl.load(header + 2 + offlutn)
        # reduction length handled by this program (SDD_K split TZ ways)
        AS1 = SDD_K // TZ
        # locking is only needed when the reduction is split (TZ > 1)
        lockid = tl.where(TZ > 1, 1, 0)
        offka = pid0 * AS1
        offkb = pid0 * AS1
        offmc = 0
        offnc = 0
        offpa = 0
        offpb = 0
        maxid = TZ
        offhc = 0
        offha = z
        offhb = z
        ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
        rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
    else:
        # 6-int header per program along axis 0
        header = lut + pid0 * 6
        offset = tl.load(header + 0)
        AS1 = tl.load(header + 1)
        column = tl.load(header + 2)
        depth = tl.load(header + 3)
        lockid = tl.load(header + 4)
        maxid = tl.load(header + 5)
        # pointer to the table of increments walked in the inner loop
        pinc = lut + offset
        offhc = depth
        if meta['DSD']:
            # output offset
            offnc = pid1 * TN
            offmc = column * TM
            offpc = 0
            # dense input offset
            offnb = pid1 * TN
            offkb = tl.load(pinc)
            offkb = tl.multiple_of(offkb, 8)  # compiler hint
            offpb = 0
            # sparse input offset
            offma = 0
            offka = 0
            offpa = tl.load(pinc + 1)
            offpa = tl.multiple_of(offpa, 8)  # compiler hint
            offpa = offpa * BLOCK * BLOCK
            offha = 0
            offhb = depth
        else:
            # output offset
            offmc = pid1 * TM
            offnc = column * TN
            offpc = 0
            # dense input offset
            offma = pid1 * TM
            offka = tl.load(pinc)
            offka = tl.multiple_of(offka, 8)  # compiler hint
            offpa = 0
            # sparse input offset
            offnb = 0
            offkb = 0
            offpb = tl.load(pinc + 1)
            offpb = tl.multiple_of(offpb, 8)  # compiler hint
            offpb = offpb * BLOCK * BLOCK
            offha = depth
            offhb = 0
        ram = offma + tl.arange(0, TM)
        rbn = offnb + tl.arange(0, TN)
    # initialize a, b pointers
    rka = offka + tl.arange(0, TK)
    rkb = offkb + tl.arange(0, TK)
    pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
    pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
    # the dense operand is bounds-checked against DS0; otherwise the load is
    # only guarded by "there is something to reduce" (AS1 > 0)
    if meta['DDS']:
        checkam = ram[:, None] < DS0
    else:
        checkam = AS1 > 0
    if meta['DSD']:
        checkbn = rbn[None, :] < DS0
    else:
        checkbn = AS1 > 0
    # first tiles are loaded before the loop; later ones are pre-fetched
    a = tl.load(pa, mask=checkam, other=0.)
    b = tl.load(pb, mask=checkbn, other=0.)
    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TM, TN), dtype=tl.float32)
    for k in range(AS1, 0, -TK):
        acc += tl.dot(a, b)
        if meta['SDD']:
            # both operands advance by a fixed TK step
            inc_a = TK * stride_ka
            inc_b = TK * stride_kb
        else:
            # advance to the next pair of increments in the look-up table
            pinc += 2
        if meta['DSD']:
            inc_b = tl.load(pinc)
            inc_a = tl.load(pinc + 1)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = inc_b * stride_kb
        if meta['DDS']:
            inc_a = tl.load(pinc)
            inc_b = tl.load(pinc + 1)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = inc_a * stride_ka
        pa += inc_a
        pb += inc_b
        # pre-fetch
        # `k > TK` is false on the final iteration, so the last pre-fetch is
        # fully masked out and its (unused) result is never consumed
        checkak = k > TK
        checkbk = k > TK
        checka = checkam & checkak
        checkb = checkbn & checkbk
        a = tl.load(pa, mask=checka)
        b = tl.load(pb, mask=checkb)
    # cast the fp32 accumulator to the element type of the output buffer
    c = acc.to(C.dtype.element_ty)
    if meta['SDD']:
        checkc = True
        rr_blockidm = tl.arange(0, TM) // BLOCK
        rr_blockidn = tl.arange(0, TN) // BLOCK
        rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
        rr_offlutn = rr_blockidn * 4
        # field 3 of each 4-int record is the output block id
        off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
        bkid = tl.load(header + off_bkid)
        offpc = bkid * BLOCK * BLOCK
        rcm = tl.arange(0, TM) % BLOCK
        rcn = tl.arange(0, TN) % BLOCK
    else:
        rcm = offmc + tl.arange(0, TM)
        rcn = offnc + tl.arange(0, TN)
        if meta['DSD']:
            checkc = rcn[None, :] < DS0
        if meta['DDS']:
            checkc = rcm[:, None] < DS0
    pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
    # write-back directly
    if lockid == 0:
        tl.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        # one lock per (axis-2 program, axis-1 program, lockid); the matching
        # counters live right after all the locks in the same buffer
        plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
        # spin until the lock is acquired (0 -> 1)
        while tl.atomic_cas(plock, 0, 1) == 1:
            pass
        count = tl.load(pcount)
        # the first writer overwrites the output, later ones accumulate into it
        if count == 0:
            tl.store(pc, c, mask=checkc)
        else:
            d = tl.load(pc, mask=checkc)
            tl.store(pc, d + c, mask=checkc)
        # the counter wraps back to 0 after `maxid` contributions
        tl.atomic_xchg(pcount, (count + 1) % maxid)
        # release the lock
        tl.atomic_xchg(plock, 0)
def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, LOCKS, **META):
    """Compute one tile of ``C = A @ B`` with optional split-K reduction.

    Grid axis 0 selects the ``BLOCK_M x BLOCK_N`` output tile and grid axis 1
    (``pid_z``) selects which ``K // SPLIT_K`` slice of the reduction this
    program handles. When ``SPLIT_K > 1`` the partial tiles of the different
    ``pid_z`` programs are summed into ``C`` under a per-tile spin-lock.

    Args:
        A, B, C: base pointers of the two inputs and the output.
            Assumes A is (M, K), B is (K, N) and C is (M, N), as suggested by
            the stride names -- TODO confirm against the launcher.
        M, N, K: problem sizes.
        stride_am, stride_ak: element strides of A along M and K.
        stride_bk, stride_bn: element strides of B along K and N.
        stride_cm, stride_cn: element strides of C along M and N.
        LOCKS: pointer to an integer buffer holding one lock per axis-0
            program followed by one counter per axis-0 program. Both are
            left at 0 once all ``SPLIT_K`` partial results are written.
        **META: meta-parameters. Reads ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``
            (tile sizes), ``GROUP_M`` (tile-row group size for the program
            re-ordering), ``SPLIT_K`` (number of reduction slices) and
            ``EVEN_K`` (when truthy, loads along K are not masked).

    Notes:
        The fp32 accumulator is cast to the element type of ``C`` before the
        write-back, so the stored (and, for split-K, the accumulated) values
        always match the output buffer's dtype.
    """
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = META['GROUP_M']
    SPLIT_K = META['SPLIT_K']
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance:
    # programs are grouped GROUP_M tile-rows at a time; the last group may
    # hold fewer than GROUP_M rows, hence the `min`.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    # From here on K is the per-program slice length.
    # NOTE(review): integer division drops the last `K % SPLIT_K` elements of
    # the reduction; presumably the launcher only picks SPLIT_K values that
    # divide K -- confirm.
    K = K // SPLIT_K
    A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # `k` counts the elements of the slice still to be processed
    for k in range(K, 0, -BLOCK_K):
        if META['EVEN_K']:
            a = tl.load(A)
            b = tl.load(B)
        else:
            # mask the tail of the slice; masked-out elements are loaded as 0
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # cast to the output element type (instead of a fixed fp16) so that the
    # value written matches whatever dtype C actually has
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # one lock per output tile, followed by one counter per output tile
        LOCKS = LOCKS + tl.program_id(0)
        COUNT = LOCKS + tl.num_programs(0)
        # spin until the lock is acquired (0 -> 1)
        while tl.atomic_cas(LOCKS, 0, 1) == 1:
            pass
        count = tl.load(COUNT)
        # the first writer overwrites C, later ones accumulate into it
        if count == 0:
            tl.store(C, acc, mask=mask)
        else:
            curr = tl.load(C, mask=mask, other=0.)
            tl.store(C, acc + curr, mask=mask)
        # the counter wraps back to 0 after SPLIT_K contributions
        tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
        # release the lock
        tl.atomic_xchg(LOCKS, 0)
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    **meta,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)

    Each program computes one [BLOCK_SIZE_M, BLOCK_SIZE_N] block of C. It
    accumulates in fp32 over K in steps of BLOCK_SIZE_K, optionally applies
    `meta['ACTIVATION']` to the accumulator, converts it to fp16 and stores
    it with a mask on the M and N bounds.

    Meta-parameters read from `meta`: BLOCK_SIZE_M, BLOCK_SIZE_N,
    BLOCK_SIZE_K (block sizes) and ACTIVATION (a callable applied to the
    fp32 accumulator, or a falsy value to skip it).

    Loads are masked along K, so K does not need to be a multiple of
    BLOCK_SIZE_K. Loads are not masked along M/N: rows/columns beyond M/N
    are still read, only the final store is masked.
    """
    # extract meta-parameters
    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
    BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
    GROUP_SIZE_M = 8
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse
    # See above `L2 Cache Optimizations` section for details
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # the last group may hold fewer than GROUP_SIZE_M rows of blocks
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # see above `Pointer Arithmetics` section for details
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Mask the K dimension: if K is not a multiple of BLOCK_SIZE_K, the
        # last block is only partially valid. Masked-out elements are loaded
        # as 0, so they neither touch out-of-bounds memory nor contribute to
        # the dot product.
        k_remaining = K - k
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        # We accumulate along the K dimension
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # you can fuse arbitrary activation functions here
    # while the accumulator is still in FP32 !
    if meta['ACTIVATION']:
        accumulator = meta['ACTIVATION'](accumulator)
    c = accumulator.to(tl.float16)
    # -----------------------------------------------------------
    # Write back the block of the output matrix C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)