Example #1
import torch
import triton
import triton.language as tl

@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
    pidhm = tl.program_id(0)
    pidz = tl.program_id(1)
    TN = meta['TN']
    BLOCK = meta['BLOCK']
    # create index ranges
    rxm = pidhm % BLOCK
    rbm = pidhm // BLOCK
    rxn = tl.arange(0, TN) % BLOCK
    rbn = tl.arange(0, TN) // BLOCK
    # extract information from look-up table
    header = LUT + rbm * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # bounds checking on lut
    check = rbn < size
    rbmn = tl.where(check, rbn, size - 1)
    # initialize pointers to block-sparse input
    blockid = tl.load(LUT + offset + rbmn * 4)
    X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
    # compute fused softmax backward
    x = tl.load(X, mask=check, other=0)
    dx = tl.load(DX, mask=check, other=0)
    x = x.to(tl.float32)
    dx = dx.to(tl.float32)
    y = x * (dx - tl.sum(x * dx, 0)) * scale
    tl.store(DX, y, mask=check)
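The update line above matches the closed-form softmax gradient: if p = softmax(scale * z), then dL/dz = scale * p * (dp - sum(p * dp)), where dp is the upstream gradient. A minimal NumPy sketch (an illustration, not part of the kernel) checking that identity against the explicit Jacobian:

import numpy as np

scale = 0.5
z = np.random.randn(8)
p = np.exp(scale * z - np.max(scale * z))
p /= p.sum()                                 # p = softmax(scale * z)
dp = np.random.randn(8)                      # upstream gradient dL/dp
dz = scale * p * (dp - np.sum(p * dp))       # fused form used by the kernel
J = scale * (np.diag(p) - np.outer(p, p))    # explicit Jacobian dp_i / dz_j
assert np.allclose(dz, J.T @ dp)             # both computations agree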
Example #2
@triton.jit
def _softmax(
        Y,  # output pointer
        X,  # input pointer
        stride_xm,  # elements to skip to move one row down in X
        stride_ym,  # elements to skip to move one row down in Y
        M,  # number of rows
        N,  # number of columns
        **meta):

    ## Note: here the stride is the number of elements you have to
    ## move in memory to get from one entry to the next along a dimension.
    ## Ex: if x is 2D, say x = torch.randn(10, 20), then
    ##   x[1, 1].data_ptr() == x[1, 0].data_ptr() + x.stride(1) * x.element_size()

    # 1 program id per row index
    m = tl.program_id(0)

    # col indices
    # here BLOCK is the smallest power of two greater than or equal to `N`
    # in Triton, block sizes always have to be powers of 2;
    # in general, powers of 2 nicely fit / align with cache lines

    n = tl.arange(0, meta['BLOCK'])

    # the memory addresses of all the elements we want to load
    # can be computed as follows:
    #   m * stride_xm jumps to the start of row m, and + n fans out across its columns

    X = X + m * stride_xm + n

    # x is the vector of values loaded.
    # lanes with n >= N fall past the end of the row, so they are masked out
    # and filled with -inf, which contributes nothing to the max or the exp-sum
    x = tl.load(X, mask=n < N, other=-float('inf'))

    # Subtract maximum for numerical stability
    z = x - tl.max(x, axis=0)
    # Note that exponentials in Triton are fast
    # but approximate (i.e., think __expf in CUDA)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)
    y = num / denom
    # Write back to Y
    Y = Y + m * stride_ym + n
    tl.store(Y, y, mask=n < N)
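The max-subtraction above works because softmax is shift-invariant: the common factor exp(-c) cancels in exp(x_i - c) / sum_j exp(x_j - c), and subtracting max(x) keeps every exponent <= 0 so tl.exp cannot overflow. A small NumPy sketch (illustrative, not part of the kernel) of why the shift matters:

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])   # naive np.exp(x) overflows to inf
z = x - x.max()                          # exponents are now <= 0
y = np.exp(z) / np.exp(z).sum()          # stable; same result as softmax(x)
print(y)                                 # ~[0.090, 0.245, 0.665]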
Example #3
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, **meta):
    BLOCK = meta['BLOCK']
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to logit and probs
    LOGITS = LOGITS + row * N + cols
    WRIT_PROBS = PROBS + row * N + cols
    READ_PROBS = PROBS + row * N + idx
    # write-back negative log-probs
    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
    logits = logits.to(tl.float32)
    logits = logits - tl.max(logits, 0)
    probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
    tl.store(WRIT_PROBS, probs, mask=cols < N)
    # There is a bug in the compiler, which fails to insert a barrier here.
    # We add it explicitly for now. Will be fixed soon.
    tl.debug_barrier()
    # write-back loss
    probs = tl.load(READ_PROBS)
    tl.store(LOSS + row, probs)
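A per-row NumPy reference (an assumption, for illustration only) of what one program instance computes above: the stabilized negative log-probabilities, and the loss as the negative log-probability at the target index.

import numpy as np

def cross_entropy_row(logits, idx):
    logits = logits - logits.max()                    # matches the tl.max shift
    neg_log_probs = np.log(np.exp(logits).sum()) - logits
    return neg_log_probs, neg_log_probs[idx]          # loss = -log p[idx]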
Example #4
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride,
                   n_cols, **meta):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    BLOCK_SIZE = meta['BLOCK_SIZE']
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # The block size is the next power of two greater than or equal to n_cols,
    # so we can fit each row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
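A minimal host-side launch sketch (the wrapper below is an assumption modeled on the Triton tutorials, not shown in the source): one program instance per row, with BLOCK_SIZE the next power of two >= n_cols, checked against torch.softmax.

def softmax(x):
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    softmax_kernel[(n_rows,)](y, x, x.stride(0), y.stride(0), n_cols,
                              BLOCK_SIZE=BLOCK_SIZE)
    return y

x = torch.randn(1823, 781, device='cuda')
assert torch.allclose(softmax(x), torch.softmax(x, axis=1))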