def _backward(PROBS, IDX, DPROBS, N, **meta):
    BLOCK = meta['BLOCK']
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to probs
    PROBS = PROBS + row * N + cols
    # We know d(-log(p[i]))/dlogit[k] = -id_mat[i,k] + p[k]
    # and we have -log(p[k]) stored in PROBS, so this is easy
    probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
    probs = tl.exp(probs.to(tl.float32))
    delta = cols == idx
    # write result in-place in PROBS
    dout = tl.load(DPROBS + row)
    din = (probs - delta) * dout
    tl.store(PROBS, din.to(tl.float16), mask=cols < N)
def _softmax(
    Y,          # output pointer
    X,          # input pointer
    stride_xm,  # row stride of X: how far to move to reach the next row
    stride_ym,  # row stride of Y
    M,          # num of rows
    N,          # num of cols
    **meta):
    ## Note: the stride is the amount you have to move in memory
    ## to get to the next element along a given dimension.
    ## Ex: if x = torch.randn(10, 20), then
    ## x[1, 1].data_ptr() == x[1, 0].data_ptr() + x.stride(1) * x.element_size()
    # 1 program id per row index
    m = tl.program_id(0)
    # col indices
    # here BLOCK is the smallest power of two greater than `N`
    # In Triton, the block size always has to be a power of 2;
    # in general powers of 2 nicely fit / align with cache lines
    n = tl.arange(0, meta['BLOCK'])
    # the memory address of all the elements that we want to load
    # can be computed as follows:
    # m * stride_xm selects the row, n offsets within the row
    X = X + m * stride_xm + n
    # x is the vector of values loaded.
    # Out-of-bounds lanes are filled with -inf so they do not
    # affect the max and contribute 0 after the exponential.
    x = tl.load(X, mask=n < N, other=-float('inf'))
    # Subtract maximum for numerical stability
    z = x - tl.max(x, axis=0)
    # Note that exponentials in Triton are fast
    # but approximate (i.e., think __expf in CUDA)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)
    y = num / denom
    # Write back to Y
    Y = Y + m * stride_ym + n
    tl.store(Y, y, mask=n < N)
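# To make the stride comment above concrete: data_ptr() is measured in bytes,
# so the per-element stride has to be scaled by the element size. A quick check
# on the CPU (plain PyTorch, nothing Triton-specific):
import torch

x = torch.randn(10, 20)
assert x[1, 1].data_ptr() == x[1, 0].data_ptr() + x.stride(1) * x.element_size()
assert x[1, 0].data_ptr() == x[0, 0].data_ptr() + x.stride(0) * x.element_size()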
def _add(
    X,      # *Pointer* to first input vector
    Y,      # *Pointer* to second input vector
    Z,      # *Pointer* to output vector
    N,      # Size of the vector
    **meta  # Optional meta-parameters for the kernel
):
    pid = tl.program_id(0)
    # Create an offset for the blocks of pointers to be
    # processed by this program instance
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    # Create a mask to guard memory operations against
    # out-of-bounds accesses
    mask = offsets < N
    # Load x and y
    x = tl.load(X + offsets, mask=mask)
    y = tl.load(Y + offsets, mask=mask)
    # Write back x + y (masked, so the last block does not write past N)
    z = x + y
    tl.store(Z + offsets, z, mask=mask)
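# A minimal host-side launcher for a kernel like _add above, assuming the kernel
# is decorated with @triton.jit (the wrapper name `add` and BLOCK=1024 are
# illustrative choices, not prescribed by the kernel itself):
import torch
import triton

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.empty_like(x)
    N = z.numel()
    # one program instance per BLOCK elements; cdiv rounds up so the tail of
    # the vector is still covered (and masked out inside the kernel)
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
    _add[grid](x, y, z, N, BLOCK=1024)
    return z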
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
    BLOCK = meta['BLOCK']
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to logits and probs
    LOGITS = LOGITS + row * N + cols
    WRIT_PROBS = PROBS + row * N + cols
    READ_PROBS = PROBS + row * N + idx
    # write-back negative log-probs
    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
    logits = logits.to(tl.float32)
    logits = logits - tl.max(logits, 0)
    probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
    tl.store(WRIT_PROBS, probs, mask=cols < N)
    # There is a bug in the compiler, which fails to insert a barrier here.
    # We add it explicitly for now. Will be fixed soon.
    tl.debug_barrier()
    # write-back loss
    probs = tl.load(READ_PROBS)
    tl.store(LOSS + row, probs)
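# The two cross-entropy kernels (_forward here, _backward earlier) are typically
# driven from a torch.autograd.Function. A minimal sketch, assuming both kernels
# are @triton.jit functions, that BLOCK is a power of two covering the number of
# columns, and that PROBS is a scratch tensor the same shape as the logits; the
# wrapper class and its names are illustrative, not taken from the source:
import torch
import triton

class _CrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, indices):
        n_rows, n_cols = logits.shape
        probs = torch.empty_like(logits)
        loss = torch.empty(n_rows, dtype=logits.dtype, device=logits.device)
        BLOCK = triton.next_power_of_2(n_cols)
        # one program per row; PROBS is filled with -log(p) for the backward pass
        _forward[(n_rows,)](logits, probs, indices, loss, n_cols, BLOCK=BLOCK)
        ctx.save_for_backward(probs, indices)
        return loss

    @staticmethod
    def backward(ctx, dloss):
        probs, indices = ctx.saved_tensors
        n_rows, n_cols = probs.shape
        BLOCK = triton.next_power_of_2(n_cols)
        # _backward overwrites PROBS in place with the gradient w.r.t. the logits
        _backward[(n_rows,)](probs, indices, dloss.contiguous(), n_cols, BLOCK=BLOCK)
        return probs, None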
def _matmul(A, B, C, M, N, K,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = 8
    # matrix multiplication
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # note: the loads below are unmasked, so this loop assumes the problem
    # sizes are multiples of the block sizes (in particular K of BLOCK_K)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # triton can accept an arbitrary activation function
    # via meta-parameters!
    if META['ACTIVATION']:
        acc = META['ACTIVATION'](acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(C, acc, mask=mask)
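# To make the grouped program-ID ordering above concrete, here is the same
# arithmetic mirrored in plain Python (the grid sizes grid_m=8, grid_n=4 are
# illustrative). With GROUP_M = 8, programs 0..31 all fall in group 0: pid_m
# sweeps rows 0..7 while pid_n advances one column every 8 programs, so nearby
# programs reuse the same rows of A from L2.
def remap(pid, grid_m=8, grid_n=4, GROUP_M=8):
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n

# remap(0) == (0, 0), remap(1) == (1, 0), ..., remap(7) == (7, 0), remap(8) == (0, 1)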
def _add(
    X,      # *pointer* to input vector 1
    Y,      # *pointer* to input vector 2
    Z,      # *pointer* to output vector
    N,      # size of the vector
    **meta  # optional meta-parameters for the kernel
            # (block size and other compile-time constants)
):
    # Roughly the start of the block.
    pid = tl.program_id(0)
    # We don't worry about individual threads here: we write the
    # program for a whole block at a time.
    # Create the offsets for the block of pointers to be processed
    # by this program instance.
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    # Create a mask to guard memory operations against
    # out-of-bounds accesses. The mask defines which lanes to load
    # and which to leave undefined.
    mask = offsets < N
    # Load x and y:
    # base pointer + arange -> yields a block of pointers
    load_x_ptrs = X + offsets
    load_y_ptrs = Y + offsets
    x = tl.load(load_x_ptrs, mask=mask)  # x <- tensor of data
    y = tl.load(load_y_ptrs, mask=mask)
    # Write back x + y (masked, so we never write past N)
    z = x + y
    tl.store(Z + offsets, z, mask=mask)
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride,
                   n_cols, **meta):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
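# A minimal host-side wrapper for softmax_kernel above, assuming it is an
# @triton.jit function and that the input is a 2D tensor (the wrapper name and
# structure are illustrative):
import torch
import triton

def softmax(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    # BLOCK_SIZE must be a power of two that covers a full row
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    # one program instance per row
    softmax_kernel[(n_rows,)](y, x, x.stride(0), y.stride(0), n_cols,
                              BLOCK_SIZE=BLOCK_SIZE)
    return y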
def _kernel(A, B, C,
            stride_za, stride_ha, stride_ma, stride_ka,
            stride_zb, stride_hb, stride_kb, stride_nb,
            stride_zc, stride_hc, stride_mc, stride_nc,
            DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta):
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
    TZ = meta['TZ']
    BLOCK = meta['BLOCK']
    #------------#
    #- Prologue -#
    #------------#
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pidz = tl.program_id(2)
    if meta['SDD']:
        pid1 = pid1 + SDD_off_width
        blockidm = tl.arange(0, TM) // BLOCK
        blockidn = tl.arange(0, TN) // BLOCK
        offlutm = blockidm * (TN // BLOCK) * 4
        offlutn = blockidn * 4
        header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
        z = tl.load(header + 0)
        i = tl.load(header + 1 + offlutm)
        j = tl.load(header + 2 + offlutn)
        AS1 = SDD_K // TZ
        lockid = tl.where(TZ > 1, 1, 0)
        offka = pid0 * AS1
        offkb = pid0 * AS1
        offmc = 0
        offnc = 0
        offpa = 0
        offpb = 0
        maxid = TZ
        offhc = 0
        offha = z
        offhb = z
        ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
        rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
    else:
        header = lut + pid0 * 6
        offset = tl.load(header + 0)
        AS1 = tl.load(header + 1)
        column = tl.load(header + 2)
        depth = tl.load(header + 3)
        lockid = tl.load(header + 4)
        maxid = tl.load(header + 5)
        pinc = lut + offset
        offhc = depth
        if meta['DSD']:
            # output offset
            offnc = pid1 * TN
            offmc = column * TM
            offpc = 0
            # dense input offset
            offnb = pid1 * TN
            offkb = tl.load(pinc)
            offkb = tl.multiple_of(offkb, 8)  # compiler hint
            offpb = 0
            # sparse input offset
            offma = 0
            offka = 0
            offpa = tl.load(pinc + 1)
            offpa = tl.multiple_of(offpa, 8)  # compiler hint
            offpa = offpa * BLOCK * BLOCK
            offha = 0
            offhb = depth
        else:
            # output offset
            offmc = pid1 * TM
            offnc = column * TN
            offpc = 0
            # dense input offset
            offma = pid1 * TM
            offka = tl.load(pinc)
            offka = tl.multiple_of(offka, 8)  # compiler hint
            offpa = 0
            # sparse input offset
            offnb = 0
            offkb = 0
            offpb = tl.load(pinc + 1)
            offpb = tl.multiple_of(offpb, 8)  # compiler hint
            offpb = offpb * BLOCK * BLOCK
            offha = depth
            offhb = 0
        ram = offma + tl.arange(0, TM)
        rbn = offnb + tl.arange(0, TN)
    # initialize a, b pointers
    rka = offka + tl.arange(0, TK)
    rkb = offkb + tl.arange(0, TK)
    pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
    pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
    if meta['DDS']:
        checkam = ram[:, None] < DS0
    else:
        checkam = AS1 > 0
    if meta['DSD']:
        checkbn = rbn[None, :] < DS0
    else:
        checkbn = AS1 > 0
    a = tl.load(pa, mask=checkam, other=0.)
    b = tl.load(pb, mask=checkbn, other=0.)
    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TM, TN), dtype=tl.float32)
    for k in range(AS1, 0, -TK):
        acc += tl.dot(a, b)
        if meta['SDD']:
            inc_a = TK * stride_ka
            inc_b = TK * stride_kb
        else:
            pinc += 2
            if meta['DSD']:
                inc_b = tl.load(pinc)
                inc_a = tl.load(pinc + 1)
                inc_b = tl.multiple_of(inc_b, 8)
                inc_a = tl.multiple_of(inc_a, 8)
                inc_b = inc_b * stride_kb
            if meta['DDS']:
                inc_a = tl.load(pinc)
                inc_b = tl.load(pinc + 1)
                inc_a = tl.multiple_of(inc_a, 8)
                inc_b = tl.multiple_of(inc_b, 8)
                inc_a = inc_a * stride_ka
        pa += inc_a
        pb += inc_b
        # pre-fetch
        checkak = k > TK
        checkbk = k > TK
        checka = checkam & checkak
        checkb = checkbn & checkbk
        a = tl.load(pa, mask=checka)
        b = tl.load(pb, mask=checkb)
    c = acc.to(C.dtype.element_ty)
    if meta['SDD']:
        checkc = True
        rr_blockidm = tl.arange(0, TM) // BLOCK
        rr_blockidn = tl.arange(0, TN) // BLOCK
        rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
        rr_offlutn = rr_blockidn * 4
        off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
        bkid = tl.load(header + off_bkid)
        offpc = bkid * BLOCK * BLOCK
        rcm = tl.arange(0, TM) % BLOCK
        rcn = tl.arange(0, TN) % BLOCK
    else:
        rcm = offmc + tl.arange(0, TM)
        rcn = offnc + tl.arange(0, TN)
    if meta['DSD']:
        checkc = rcn[None, :] < DS0
    if meta['DDS']:
        checkc = rcm[:, None] < DS0
    pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
    # write-back directly
    if lockid == 0:
        tl.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
        while tl.atomic_cas(plock, 0, 1) == 1:
            pass
        count = tl.load(pcount)
        if count == 0:
            tl.store(pc, c, mask=checkc)
        else:
            d = tl.load(pc, mask=checkc)
            tl.store(pc, d + c, mask=checkc)
        tl.atomic_xchg(pcount, (count + 1) % maxid)
        tl.atomic_xchg(plock, 0)
def _kernel(A, B, C, M, N, K,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            LOCKS, **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = META['GROUP_M']
    SPLIT_K = META['SPLIT_K']
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    K = K // SPLIT_K
    A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        if META['EVEN_K']:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    acc = acc.to(tl.float16)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        LOCKS = LOCKS + tl.program_id(0)
        COUNT = LOCKS + tl.num_programs(0)
        while tl.atomic_cas(LOCKS, 0, 1) == 1:
            pass
        count = tl.load(COUNT)
        if count == 0:
            tl.store(C, acc, mask=mask)
        else:
            curr = tl.load(C, mask=mask, other=0.)
            tl.store(C, acc + curr, mask=mask)
        tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
        tl.atomic_xchg(LOCKS, 0)
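# A launch sketch for the split-K kernel above (illustrative, not the library's
# exact wrapper; launch_split_k and its defaults are hypothetical). Grid axis 0
# enumerates output tiles of C, axis 1 enumerates the SPLIT_K slices of the
# reduction. From the kernel's own indexing (LOCKS + program_id(0) and
# COUNT = LOCKS + num_programs(0)), LOCKS needs at least 2 * n_tiles int32
# entries: the first half are spin-locks, the second half arrival counters.
import torch
import triton

def launch_split_k(a, b, c, SPLIT_K=4, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32):
    M, K = a.shape
    _, N = b.shape
    n_tiles = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
    locks = torch.zeros(2 * n_tiles, dtype=torch.int32, device=a.device)
    _kernel[(n_tiles, SPLIT_K)](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        locks,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        GROUP_M=8, SPLIT_K=SPLIT_K,
        EVEN_K=(K % (BLOCK_K * SPLIT_K) == 0),
    )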
def kernel(X, Z, **meta):
    pid = tl.program_id(0)
    old = tl.atomic_add(X, pid)
    tl.store(Z + pid, old)
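# An illustrative launch of the atomic-add kernel above (assuming @triton.jit
# and a CUDA device): every program atomically adds its pid into the single
# counter at X, and Z[pid] records the value X held just before that program's
# update.
import torch

n = 32
x = torch.zeros(1, dtype=torch.int32, device='cuda')
z = torch.empty(n, dtype=torch.int32, device='cuda')
kernel[(n,)](x, z)
# x now holds 0 + 1 + ... + (n - 1); the order in which the "old" values land
# in z depends on how the hardware schedules the programs.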
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    **meta,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # extract meta-parameters
    BLOCK_SIZE_M = meta['BLOCK_SIZE_M']
    BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
    BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
    GROUP_SIZE_M = 8
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse
    # See above `L2 Cache Optimizations` section for details
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # see above `Pointer Arithmetics` section for details
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Note that for simplicity, we don't apply a mask here.
        # This means that if K is not a multiple of BLOCK_SIZE_K,
        # this will access out-of-bounds memory and produce an
        # error or (worse!) incorrect results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # We accumulate along the K dimension
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # you can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if meta['ACTIVATION']:
        accumulator = meta['ACTIVATION'](accumulator)
    c = accumulator.to(tl.float16)
    # -----------------------------------------------------------
    # Write back the block of the output matrix C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
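# A host-side wrapper for matmul_kernel above, following the tutorial-style
# launch pattern (assuming an @triton.jit kernel and fp16 inputs; the wrapper
# name, the fixed block sizes, and ACTIVATION=None are illustrative choices):
import torch
import triton

def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    K, N = b.shape
    # the kernel converts the accumulator to fp16 before storing
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch grid: one program per [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of C
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32,
        ACTIVATION=None,
    )
    return c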
def kernel(X, Z, **meta):
    pid = tl.program_id(0)
    x = tl.load(X + pid)
    # GENERATE_TEST_HERE is left as a placeholder, presumably substituted
    # with a concrete expression by a test generator.
    old = GENERATE_TEST_HERE