def _forward(X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe,
             stride_zkpm, stride_zattnm, **meta):
    # Fused, in-place softmax over one row of block-sparse data stored in X.
    #
    # Program axis 0 encodes row_block * BLOCK + row_in_block; program axis 1
    # selects the slice of X addressed through stride_zx.
    # LUT layout as read here:
    #   * a header of two ints per row block: (number of blocks, entry offset)
    #   * at LUT + offset, four ints per block: (block id, column id, row id, head id)
    # Compile-time switches in meta: APPLY_SCALE, APPLY_RPE, APPLY_KP_MASK,
    # KP_MASK_MUL, APPLY_ATTN_MASK, ATTN_MASK_MUL; tile sizes TN and BLOCK.
    # `sizemax` is not read in the body (presumably consumed by launch
    # heuristics -- TODO confirm against the caller).
    BLOCK = meta['BLOCK']
    TN = meta['TN']
    row_pid = tl.program_id(0)
    z_idx = tl.program_id(1)
    # split the row program id into (row block, row inside that block)
    row_block = row_pid // BLOCK
    row_in_block = row_pid % BLOCK
    # each of the TN lanes handles one column: (slot in the block list, column inside the block)
    lanes = tl.arange(0, TN)
    col_in_block = lanes % BLOCK
    slot = lanes // BLOCK
    # read the per-row-block header
    row_header = LUT + row_block * 2
    num_blocks = tl.load(row_header + 0)
    first_entry = tl.load(row_header + 1)
    # lanes past the last block are masked; their LUT index is clamped to the last valid entry
    in_range = slot < num_blocks
    safe_slot = tl.where(in_range, slot, num_blocks - 1)
    # per-lane LUT entry
    entry = LUT + first_entry + safe_slot * 4
    block_id = tl.load(entry + 0)
    column_id = tl.load(entry + 1)
    row_id = tl.load(entry + 2)
    head_id = tl.load(entry + 3)
    # load the row; masked lanes read -inf so they contribute nothing to the softmax
    px = X + z_idx * stride_zx + block_id * BLOCK * BLOCK + row_in_block * BLOCK + col_in_block
    vals = tl.load(px, mask=in_range, other=-float('inf'))
    vals = vals.to(tl.float32)
    # optional scaling
    if meta['APPLY_SCALE']:
        vals = vals * scale
    # optional additive relative position embedding
    if meta['APPLY_RPE']:
        prpe = (RPE + z_idx * stride_zrpe + head_id * stride_hrpe + column_id * BLOCK +
                row_id * BLOCK * stride_srpe + row_in_block * stride_srpe + col_in_block)
        rpe = tl.load(prpe, mask=in_range, other=0)
        vals = vals + rpe
    # optional key-padding mask (indexed by column only)
    if meta['APPLY_KP_MASK']:
        pkp_m = KP_M + z_idx * stride_zkpm + column_id * BLOCK + col_in_block
        kp_m = tl.load(pkp_m, mask=in_range, other=-float('inf'))
        # multiplicative convention: a zero entry becomes -inf, anything else becomes 0
        if meta['KP_MASK_MUL']:
            kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
        vals = vals + kp_m
    # optional attention mask (indexed by row and column; stride_zattnm is used as the row stride)
    if meta['APPLY_ATTN_MASK']:
        pattn_m = (ATTN_M + column_id * BLOCK + row_id * BLOCK * stride_zattnm +
                   row_in_block * stride_zattnm + col_in_block)
        attn_m = tl.load(pattn_m, mask=in_range, other=-float('inf'))
        # multiplicative convention: a zero entry becomes -inf, anything else becomes 0
        if meta['ATTN_MASK_MUL']:
            attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
        vals = vals + attn_m
    # normalize and overwrite the input row
    vals = tl.softmax(vals)
    tl.store(px, vals, mask=in_range)
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
    # Fused softmax backward over one row of block-sparse data; the result
    # overwrites DX in place.
    #
    # Program axis 0 encodes row_block * BLOCK + row_in_block; program axis 1
    # selects the slice addressed through stride_zx / stride_zdx.
    # Computes y = x * (dx - sum(x * dx)) * scale over the row, where x is read
    # from X (presumably the softmax output saved by the forward pass -- TODO
    # confirm) and dx is read from DX.
    # `sizemax` is not read in the body.
    BLOCK = meta['BLOCK']
    TN = meta['TN']
    row_pid = tl.program_id(0)
    z_idx = tl.program_id(1)
    # split the row program id into (row block, row inside that block)
    row_block = row_pid // BLOCK
    row_in_block = row_pid % BLOCK
    # each lane handles one column: (slot in the block list, column inside the block)
    lanes = tl.arange(0, TN)
    col_in_block = lanes % BLOCK
    slot = lanes // BLOCK
    # LUT header: two ints per row block -> (number of blocks, entry offset)
    row_header = LUT + row_block * 2
    num_blocks = tl.load(row_header + 0)
    first_entry = tl.load(row_header + 1)
    # lanes past the last block are masked; their LUT index is clamped to the last valid entry
    in_range = slot < num_blocks
    safe_slot = tl.where(in_range, slot, num_blocks - 1)
    # first int of each 4-int LUT entry is the block id
    block_id = tl.load(LUT + first_entry + safe_slot * 4)
    px = X + z_idx * stride_zx + block_id * BLOCK * BLOCK + row_in_block * BLOCK + col_in_block
    pdx = DX + z_idx * stride_zdx + block_id * BLOCK * BLOCK + row_in_block * BLOCK + col_in_block
    # masked lanes read 0 so they do not contribute to the row reduction
    out = tl.load(px, mask=in_range, other=0)
    grad = tl.load(pdx, mask=in_range, other=0)
    out = out.to(tl.float32)
    grad = grad.to(tl.float32)
    row_dot = tl.sum(out * grad, 0)
    result = out * (grad - row_dot) * scale
    tl.store(pdx, result, mask=in_range)
def leaky_relu(x):
    # Elementwise leaky ReLU with a fixed negative slope of 0.01:
    # non-negative inputs pass through, negative inputs are scaled by 0.01.
    damped = 0.01 * x
    keep = x >= 0
    return tl.where(keep, x, damped)
def _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc,
            stride_hc, stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta):
    # Tiled matrix multiply C = A x B whose tile placement is driven by a
    # look-up table (`lut`). One of three modes is selected at compile time
    # through meta['SDD'], meta['DSD'] and meta['DDS'] (presumably exactly one
    # of them is set by the caller -- TODO confirm):
    #   * SDD: operand tiles are addressed with (z, i, j) coordinates read from
    #     the LUT and the result is written at offset bkid * BLOCK * BLOCK in C,
    #     where bkid is also read from the LUT.
    #   * DSD: A is the sparse input (block offset read from the LUT), B is dense.
    #   * otherwise: A is the dense input, B is the sparse one.
    # meta['TM'], meta['TN'], meta['TK'] are the tile sizes, meta['BLOCK'] the
    # sparse block size, meta['TZ'] divides the reduction length in SDD mode.
    # When `lockid` is non-zero, several programs contribute to the same output
    # tile and accumulate into C under a spin-lock stored in `locks`.
    # `DS1` is not read in the body.
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
    TZ = meta['TZ']
    BLOCK = meta['BLOCK']
    #------------#
    #- Prologue -#
    #------------#
    pid0 = tl.program_id(0)
    pid1 = tl.program_id(1)
    pidz = tl.program_id(2)
    if meta['SDD']:
        pid1 = pid1 + SDD_off_width
        # block coordinates of every row / column of the tile
        blockidm = tl.arange(0, TM) // BLOCK
        blockidn = tl.arange(0, TN) // BLOCK
        # each LUT record is 4 ints; a tile owns (TM // BLOCK) * (TN // BLOCK) records
        offlutm = blockidm * (TN // BLOCK) * 4
        offlutn = blockidn * 4
        header = lut + pid1 * (TM // BLOCK) * (TN // BLOCK) * 4
        z = tl.load(header + 0)
        i = tl.load(header + 1 + offlutm)
        j = tl.load(header + 2 + offlutn)
        # reduction length handled by this program: SDD_K split across TZ programs along axis 0
        AS1 = SDD_K // TZ
        # a lock is only needed when the reduction is split
        lockid = tl.where(TZ > 1, 1, 0)
        offka = pid0 * AS1
        offkb = pid0 * AS1
        offmc = 0
        offnc = 0
        offpa = 0
        offpb = 0
        maxid = TZ
        offhc = 0
        offha = z
        offhb = z
        ram = i * BLOCK + (tl.arange(0, TM) % BLOCK)
        rbn = j * BLOCK + (tl.arange(0, TN) % BLOCK)
    else:
        # LUT header: 6 ints per program along axis 0
        header = lut + pid0 * 6
        offset = tl.load(header + 0)
        AS1 = tl.load(header + 1)
        column = tl.load(header + 2)
        depth = tl.load(header + 3)
        lockid = tl.load(header + 4)
        maxid = tl.load(header + 5)
        # pointer to the (dense increment, sparse increment) pairs consumed by the inner loop
        pinc = lut + offset
        offhc = depth
        if meta['DSD']:
            # output offset
            offnc = pid1 * TN
            offmc = column * TM
            offpc = 0
            # dense input offset
            offnb = pid1 * TN
            offkb = tl.load(pinc)
            offkb = tl.multiple_of(offkb, 8)  # compiler hint
            offpb = 0
            # sparse input offset
            offma = 0
            offka = 0
            offpa = tl.load(pinc + 1)
            offpa = tl.multiple_of(offpa, 8)  # compiler hint
            offpa = offpa * BLOCK * BLOCK
            offha = 0
            offhb = depth
        else:
            # output offset
            offmc = pid1 * TM
            offnc = column * TN
            offpc = 0
            # dense input offset
            offma = pid1 * TM
            offka = tl.load(pinc)
            offka = tl.multiple_of(offka, 8)  # compiler hint
            offpa = 0
            # sparse input offset
            offnb = 0
            offkb = 0
            offpb = tl.load(pinc + 1)
            offpb = tl.multiple_of(offpb, 8)  # compiler hint
            offpb = offpb * BLOCK * BLOCK
            offha = depth
            offhb = 0
        ram = offma + tl.arange(0, TM)
        rbn = offnb + tl.arange(0, TN)
    # initialize a, b pointers
    rka = offka + tl.arange(0, TK)
    rkb = offkb + tl.arange(0, TK)
    pa = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, None] * stride_ma + rka[None, :] * stride_ka
    pb = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[None, :] * stride_nb + rkb[:, None] * stride_kb
    # the dense operand is bounds-checked against DS0; otherwise the load is only
    # guarded by "there is something to reduce"
    if meta['DDS']:
        checkam = ram[:, None] < DS0
    else:
        checkam = AS1 > 0
    if meta['DSD']:
        checkbn = rbn[None, :] < DS0
    else:
        checkbn = AS1 > 0
    a = tl.load(pa, mask=checkam, other=0.)
    b = tl.load(pb, mask=checkbn, other=0.)
    ## ---------------- ##
    ##    Inner Loop    ##
    ## ---------------- ##
    acc = tl.zeros((TM, TN), dtype=tl.float32)
    for k in range(AS1, 0, -TK):
        acc += tl.dot(a, b)
        if meta['SDD']:
            # both operands advance by one TK-wide step along the reduction axis
            inc_a = TK * stride_ka
            inc_b = TK * stride_kb
        else:
            # move to the next (dense increment, sparse increment) pair in the LUT
            pinc += 2
        if meta['DSD']:
            inc_b = tl.load(pinc)
            inc_a = tl.load(pinc + 1)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = inc_b * stride_kb
        if meta['DDS']:
            inc_a = tl.load(pinc)
            inc_b = tl.load(pinc + 1)
            inc_a = tl.multiple_of(inc_a, 8)
            inc_b = tl.multiple_of(inc_b, 8)
            inc_a = inc_a * stride_ka
        pa += inc_a
        pb += inc_b
        # pre-fetch the next tiles; on the last iteration k <= TK, so the loads are fully masked
        checkak = k > TK
        checkbk = k > TK
        checka = checkam & checkak
        checkb = checkbn & checkbk
        a = tl.load(pa, mask=checka)
        b = tl.load(pb, mask=checkb)
    # cast the float32 accumulator to the element type of C
    c = acc.to(C.dtype.element_ty)
    if meta['SDD']:
        checkc = True
        # look up, per output element, the id of the sparse block it belongs to
        rr_blockidm = tl.arange(0, TM) // BLOCK
        rr_blockidn = tl.arange(0, TN) // BLOCK
        rr_offlutm = rr_blockidm * (TN // BLOCK) * 4
        rr_offlutn = rr_blockidn * 4
        off_bkid = 3 + rr_offlutm[:, None] + rr_offlutn[None, :]
        bkid = tl.load(header + off_bkid)
        offpc = bkid * BLOCK * BLOCK
        rcm = tl.arange(0, TM) % BLOCK
        rcn = tl.arange(0, TN) % BLOCK
    else:
        rcm = offmc + tl.arange(0, TM)
        rcn = offnc + tl.arange(0, TN)
        # NOTE(review): checkc is only assigned when DSD or DDS is set
        if meta['DSD']:
            checkc = rcn[None, :] < DS0
        if meta['DDS']:
            checkc = rcm[:, None] < DS0
    pc = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, None] * stride_mc + rcn[None, :] * stride_nc
    # write-back directly
    if lockid == 0:
        tl.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        # one lock per (program axis 2, program axis 1, lockid); the matching
        # counters live right after the whole lock region
        plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
        # spin until the compare-and-swap 0 -> 1 no longer returns 1
        while tl.atomic_cas(plock, 0, 1) == 1:
            pass
        count = tl.load(pcount)
        # a zero counter means no partial result was written yet: overwrite; otherwise accumulate
        if count == 0:
            tl.store(pc, c, mask=checkc)
        else:
            d = tl.load(pc, mask=checkc)
            tl.store(pc, d + c, mask=checkc)
        # advance the counter (wraps to 0 after maxid contributions), then release the lock
        tl.atomic_xchg(pcount, (count + 1) % maxid)
        tl.atomic_xchg(plock, 0)
def sigmoid(x):
    # Numerically stable elementwise logistic function 1 / (1 + exp(-x)).
    #
    # A single exponential of a non-positive argument, z = exp(-|x|), is shared
    # by both branches, so it can never overflow:
    #   x >= 0:  1 / (1 + z)   with z = exp(-x)
    #   x <  0:  z / (1 + z)   with z = exp(x)
    # The selected values are the same as evaluating the two closed forms
    # separately, but no inf / NaN intermediates are produced in the
    # discarded lanes and only one exp is evaluated instead of three.
    non_negative = x >= 0
    # -|x| built from tl.where so that only tl.where and tl.exp are required
    neg_abs = tl.where(non_negative, -x, x)
    z = tl.exp(neg_abs)
    denom = 1 + z
    return tl.where(non_negative, 1 / denom, z / denom)