def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
    # Fused backward pass of a block-sparse softmax for one row of the sparse
    # layout: computes y = x * (dx - sum(x * dx)) * scale over the row and
    # writes y back into DX in place.
    #
    # Parameters (X, DX and LUT are Triton device pointers):
    #   X          -- block-sparse values loaded as `x`; presumably the forward
    #                 softmax output, since the update is the softmax
    #                 Jacobian-vector product in that case -- TODO confirm
    #   scale      -- multiplier applied to the result
    #   DX         -- block-sparse incoming gradient; overwritten with y
    #   LUT        -- look-up table: a 2-entry header (size, offset) per value of
    #                 rbm, plus 4-entry-wide per-block records starting at `offset`
    #   sizemax    -- not read in this body
    #   stride_zx  -- pointer step in X per program_id(1)
    #   stride_zdx -- pointer step in DX per program_id(1)
    #   meta['TN']    -- lanes per program; presumably must be >= size * BLOCK
    #                    for the longest row -- verify against the launcher
    #   meta['BLOCK'] -- side length of each square sparse block
    pidhm = tl.program_id(0)
    pidz = tl.program_id(1)
    TN = meta['TN']
    BLOCK = meta['BLOCK']
    # create index ranges
    # pidhm is split into rbm (selects the LUT header) and rxm (row offset
    # inside each block); each of the TN lanes maps to rbn (which listed block
    # of the row) and rxn (column offset inside that block).
    rxm = pidhm % BLOCK
    rbm = pidhm // BLOCK
    rxn = tl.arange(0, TN) % BLOCK
    rbn = tl.arange(0, TN) // BLOCK
    # extract information from look-up table
    # header[0] = number of blocks listed for this row, header[1] = offset of
    # their records inside LUT.
    header = LUT + rbm * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # bounds checking on lut
    # Lanes past the last listed block are masked off; their block index is
    # clamped to size - 1 so the LUT read below stays inside this row's records.
    check = rbn < size
    rbmn = tl.where(check, rbn, size - 1)
    # initialize pointers to block-sparse input
    # Records are 4 entries wide; only the first one (the block id) is read here.
    # Each block occupies BLOCK * BLOCK contiguous elements laid out row-major.
    blockid = tl.load(LUT + offset + rbmn * 4)
    X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    # compute fused softmax backward
    # Masked lanes load 0, so they contribute nothing to the reduction below.
    x = tl.load(X, mask=check, other=0)
    dx = tl.load(DX, mask=check, other=0)
    # Do the arithmetic in float32 regardless of the storage dtype.
    x = x.to(tl.float32)
    dx = dx.to(tl.float32)
    # tl.sum(..., 0) reduces across all lanes, i.e. the whole row held by this program.
    y = x * (dx - tl.sum(x * dx, 0)) * scale
    # The result is written back into DX in place.
    tl.store(DX, y, mask=check)
def _softmax(
    Y,  # output pointer
    X,  # input pointer
    stride_xm,  # elements to advance in X to reach the next row
    stride_ym,  # elements to advance in Y to reach the next row
    M,  # number of rows (not read inside the kernel)
    N,  # number of columns
    **meta):
    # Row-wise softmax: the program with id i normalizes row i of X into row i of Y.
    #
    # meta['BLOCK'] is the number of lanes per program. It must be a power of
    # two (tl.arange requires it) and at least N; otherwise the tail of the
    # row is never loaded. Lanes at or beyond N are masked out.
    #
    # Strides are in elements, not bytes: Triton pointer arithmetic scales by
    # the element size, unlike torch's data_ptr(), which is a byte address.
    row = tl.program_id(0)
    cols = tl.arange(0, meta['BLOCK'])
    valid = cols < N

    # Padding lanes read -inf: they can never be the row maximum, and
    # exp(-inf) == 0 adds nothing to the denominator.
    vals = tl.load(X + row * stride_xm + cols, mask=valid, other=-float('inf'))

    # Shift by the row maximum so exp() cannot overflow.
    shifted = vals - tl.max(vals, axis=0)

    # tl.exp is fast but approximate (comparable to __expf in CUDA).
    e = tl.exp(shifted)
    tl.store(Y + row * stride_ym + cols, e / tl.sum(e, axis=0), mask=valid)
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
    # Cross-entropy forward for one row (row = program_id(0)).
    # - Writes -log(softmax(logits)) for every valid column of the row into PROBS.
    # - Copies the entry at the target column IDX[row] into LOSS[row].
    #
    # Parameters (device pointers unless noted):
    #   LOGITS -- input logits; row r starts at LOGITS + r * N, i.e. rows are
    #             assumed contiguous with stride N
    #   PROBS  -- output with the same layout; receives negative log-probabilities
    #   IDX    -- per-row target column index
    #   LOSS   -- per-row output loss
    #   N      -- number of columns per row (scalar)
    #   meta['BLOCK'] -- lanes per program; presumably the caller guarantees
    #                    BLOCK >= N -- verify against the launcher
    #
    # NOTE(review): IDX[row] is not bounds-checked. An index >= N reads past
    # this row through READ_PROBS.
    BLOCK = meta['BLOCK']
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to logit and probs
    LOGITS = LOGITS + row * N + cols
    WRIT_PROBS = PROBS + row * N + cols
    # Scalar pointer to the target column's slot in PROBS.
    READ_PROBS = PROBS + row * N + idx
    # write-back negative log-probs
    # Padding lanes read -inf so they neither win the max nor add to the sum.
    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
    logits = logits.to(tl.float32)
    # Shift by the row max for a numerically stable log-sum-exp.
    logits = logits - tl.max(logits, 0)
    # log(sum(exp(l))) - l == -log(softmax(l))
    probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
    tl.store(WRIT_PROBS, probs, mask=cols < N)
    # There is a bug in the compiler, which fails to insert a barrier here.
    # We add it explicitly for now. Will be fixed soon.
    # (READ_PROBS points into the values just stored via WRIT_PROBS, so those
    # stores must be complete before the load below.)
    tl.debug_barrier()
    # write-back loss
    probs = tl.load(READ_PROBS)
    tl.store(LOSS + row, probs)
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, **meta):
    # Row-wise softmax. program_id(0) selects the row; rows are independent,
    # so each program handles exactly one row.
    #
    #   input_row_stride / output_row_stride -- elements between consecutive rows
    #   n_cols             -- number of valid columns in each row
    #   meta['BLOCK_SIZE'] -- lanes per program; must be >= n_cols so that the
    #                         whole row fits in a single block (a power of two,
    #                         as tl.arange requires)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    pid = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    in_row = offsets < n_cols

    src = input_ptr + pid * input_row_stride + offsets
    dst = output_ptr + pid * output_row_stride + offsets

    # Padding lanes read -inf: they never become the max and exp(-inf) adds 0.
    x = tl.load(src, mask=in_row, other=-float('inf'))

    # Subtract the row max so exp() cannot overflow. tl.exp is fast but
    # approximate (comparable to __expf in CUDA).
    numer = tl.exp(x - tl.max(x, axis=0))
    denom = tl.sum(numer, axis=0)
    tl.store(dst, numer / denom, mask=in_row)