# Example #1
def _get_oned_copy_kernel(dtype, shape):
    """
    Build a CUDA kernel that copies a strided 1-D tensor into another strided
    1-D tensor.

    Arguments:
        dtype (np.dtype): element type of the tensors being copied.
        shape (tuple): tensor shape; only ``shape[0]`` (the element count)
            is used.

    Returns:
        The compiled ``copy_oned`` kernel. It is prepared with the argument
        signature ``(out, in, dim, src_str, dst_str)``. ``grid``/``block`` are
        set for one thread per element with 32 threads per block. ``args``
        holds only the scalar ``(shape[0],)``; the two 64-bit stride
        arguments are not included and must be supplied at launch.
    """
    copy = r"""
__global__ void copy_oned(%(type)s* out, const %(type)s* in, int dim, long long src_str,
                          long long dst_str)
{
    int tid_x = threadIdx.x;
    int idx = blockIdx.x;

    // 32 threads per block: element index = blockIdx.x * 32 + threadIdx.x
    idx = (idx << 5) + tid_x;

    const %(type)s* in0 = in + (src_str * idx);
    %(type)s* out0 = out + (dst_str * idx);

    // the last block may run past the end of the tensor
    if(idx < dim) *out0 = *in0;
}
"""
    code = copy % dict(
        type=_get_register_type(dtype, memory=True)
    )

    module = SourceModule(code)
    kernel = module.get_function("copy_oned")
    # out, in: pointers; dim: 32-bit int; src_str, dst_str: 64-bit ints
    kernel.prepare("PPIqq")

    kernel.grid = (_ceil_div(shape[0], 32), 1, 1)
    kernel.block = (32, 1, 1)
    kernel.args = (shape[0], )

    return kernel
def _get_lut_bprop_kernel(dtype, in_dtype, deterministic=False):
    """
    Builds the bprop kernel for lookup table layers based on templated code.
    If the deterministic version is requested, an index buffer must be passed
    as an argument. This index buffer re-orders items in the input tensor
    so that word_ids are sorted. This is required since we need to be sure that
    each thread only updates weights for one word id.

    Arguments:
        dtype (np.dtype): The data type which the kernel will operate on.
        in_dtype (np.dtype): The data type of the input (word id) tensor.
        deterministic (boolean): Builds the deterministic kernel when this is
            set to True.

    Returns:
        The compiled ``lut_bprop`` kernel, prepared with its argument
        signature.
    """
    if not deterministic:
        code = r"""
__global__ void lut_bprop(
    %(in_dtype)s* inputs, %(type)s* dW, %(type)s* errors, const int nin,
    const int embedding_dim, const int vocab_size, const int pad_idx)
{
    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;

    int word_id = inputs[bid];
    int error_row = bid * embedding_dim;
    int output_row = word_id * embedding_dim;

    // Several blocks may carry the same word id, so the updates to dW are
    // atomic; their order (and hence the float rounding) is not fixed.
    if(word_id != pad_idx)
    {
        for(int i = tid; i < embedding_dim; i += blockDim.x)
        {
            atomicAdd(&dW[output_row + i], errors[error_row + i]);
        }
    }
}
"""
        # The template needs %(in_dtype)s as well as %(type)s; formatting it
        # with only "type" raised KeyError. Take the input type from the same
        # helper the deterministic branch uses, but keep this branch's
        # register type for dW/errors.
        template_vals = dict(_configure_template_vals_bprop(in_dtype, dtype))
        template_vals["type"] = _get_register_type(dtype)
        code = code % template_vals

        module = SourceModule(code, options=["--use_fast_math"])
        kernel = module.get_function("lut_bprop")
        # inputs, dW, errors, nin, embedding_dim, vocab_size, pad_idx
        kernel.prepare("PPPIIIi")
    else:
        code = r"""
__global__ void lut_bprop(
    %(in_dtype)s* inputs, int* index_buffer, %(type)s* dW, %(type)s* errors,
    const int nin, const int embedding_dim, const int vocab_size,
    const int pad_idx %(stats_args)s)
{
    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;

    int index_position = bid;
    int index = index_buffer[index_position];
    int word_id = inputs[index] %(compute_input)s;
    // NOTE(review): not updated below; kept from the original template in
    // case flex statistics expect it -- confirm against the flex helpers.
    int intermediate_max = 0;

    // index_buffer lists input positions sorted by word id. Only the block at
    // the start of each run of equal word ids does any work, so each row of
    // dW is written by a single block and no atomics are needed.
    if((bid == 0 || word_id != (inputs[index_buffer[bid - 1]] %(compute_input)s)) &&
        word_id != pad_idx)
    {
        int output_row = word_id * embedding_dim;

        // Walk forward through the sorted run, accumulating every error row
        // that belongs to this word id.
        do {
            int error_row = index * embedding_dim;

            for(int i = tid; i < embedding_dim; i += blockDim.x)
            {
                dW[output_row + i] += errors[error_row + i];
            }

            index_position++;
            if(index_position == gridDim.x)
            {
                break;
            }
            index = index_buffer[index_position];
        } while((inputs[index] %(compute_input)s) == word_id);
    }
}
"""
        code %= _configure_template_vals_bprop(in_dtype, dtype)

        module = SourceModule(code, options=["--use_fast_math"])
        kernel = module.get_function("lut_bprop")
        # inputs, index_buffer, dW, errors, nin, embedding_dim, vocab_size,
        # pad_idx, followed by any flex statistics arguments
        kernel.prepare("PPPPIIIi" + flex_sig_bprop(in_dtype.str[1]))

    kernel.name = lut_bprop_kernel_name
    return kernel
def _get_sorting_kernel(kernel_id, block_size, in_dtype):
    """
    Builds kernels used for sorting inputs. There are several kernels here
    corresponding to the steps in the algorithm. The algorithm works by
    determining the sorted position for each input item. This is done with
    a bucket sort algorithm, where each word_id is a bucket. The first step
    determines the size of each bucket (number of occurrences of each word_id).
    Next, a prefix sum is computed over the list of bucket sizes to find
    where each bucket will be placed in the output buffer. Finally, each thread
    places its index into the correct sorted position based on the bucket
    start index (computed from the prefix sum) and that thread's offset into
    the bucket (which is taken from the output of the atomic add done in the
    first step.)

    Arguments:
        kernel_id (Integer): Which step to build the kernel for [0, 4]
        block_size (Integer): Number of threads per block for the prefix sum
        in_dtype (np.dtype): The data type of the input (word id) tensor.

    Returns:
        The compiled ``sort_inputs<kernel_id>`` kernel, prepared with its
        argument signature.
    """
    code = r"""
#define THREADS %(threads)s
#define STORE_BLOCKSUM %(store_blocksum)s

// Step 0: count occurrences of each word id. The value atomicAdd returns is
// this input's offset inside its word id's bucket.
__global__ void sort_inputs0(
        %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
        const int vocab_size, const int input_length %(stats_args)s)
{
    const int tid = threadIdx.x + (blockDim.x * blockIdx.x);
    int word_id;

    if(tid < input_length)
    {
        word_id = inputs[tid] %(compute_input)s;
        offset_buffer[tid] = atomicAdd(&word_counts[word_id], 1);
    }
}

// Exclusive prefix sum (up-sweep / down-sweep) over 2 * THREADS elements per
// block, two elements per thread. Assumes blockDim.x == THREADS.
// When STORE_BLOCKSUM is set, each block's total is written to
// blocksum[blockIdx.x].
__device__ void scan(int* buffer, int* blocksum, int global_length)
{
    const int tid = (threadIdx.x << 1) + 1;
    const int gid = ((threadIdx.x + (blockIdx.x * blockDim.x)) << 1) + 1;

    __shared__ int local_counts[THREADS * 2];
    local_counts[tid] = buffer[gid];
    local_counts[tid - 1] = buffer[gid - 1];

    #pragma unroll
    for(int skip = 1; skip <= THREADS; skip <<= 1)
    {
        int mask = (skip << 1) - 1;
        if((tid & mask) == mask)
        {
            local_counts[tid] += local_counts[tid - skip];
        }

        // the next level reads partial sums written by other threads
        __syncthreads();
    }

    if(tid == (THREADS * 2 - 1))
    {
        // step 2 passes a null blocksum; STORE_BLOCKSUM is expected to be
        // disabled for it -- presumably set by _configure_template_vals_sort
        if(STORE_BLOCKSUM)
        {
            blocksum[blockIdx.x] = local_counts[tid];
        }
        local_counts[tid] = 0;
    }

    #pragma unroll
    for(int skip = THREADS; skip > 0; skip >>= 1)
    {
        int mask = (skip << 1) - 1;
        if((tid & mask) == mask)
        {
            int temp = local_counts[tid - skip];
            local_counts[tid - skip] = local_counts[tid];
            local_counts[tid] += temp;
        }

        __syncthreads();
    }

    if(gid < global_length)
    {
        buffer[gid] = local_counts[tid];
        buffer[gid - 1] = local_counts[tid - 1];
    }
}

// Step 1: per-block scan of the bucket sizes; block totals are stored after
// the vocab_size counts.
__global__ void sort_inputs1(
        %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
        const int vocab_size, const int input_length %(stats_args)s)
{
    scan(word_counts, word_counts + vocab_size, vocab_size);
}

// Step 2: scan the block totals (assumed to be launched as a single block).
__global__ void sort_inputs2(
        %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
        const int vocab_size, const int input_length %(stats_args)s)
{
    scan(word_counts + vocab_size, 0, blockDim.x);
}

// Step 3: add each block's scanned total to the counts that block covered.
__global__ void sort_inputs3(
        %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
        const int vocab_size, const int input_length %(stats_args)s)
{
    const int gid = (threadIdx.x + (blockIdx.x * blockDim.x)) << 1;

    if(gid < vocab_size)
    {
        word_counts[gid] += word_counts[vocab_size + blockIdx.x];
        word_counts[gid + 1] += word_counts[vocab_size + blockIdx.x];
    }
}

// Step 4: scatter each input index to bucket start + offset within bucket.
__global__ void sort_inputs4(
        %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
        const int vocab_size, const int input_length %(stats_args)s)
{
    const int tid = threadIdx.x + (blockDim.x * blockIdx.x);
    int word_id;

    if(tid < input_length)
    {
        word_id = inputs[tid] %(compute_input)s;
        int sorted_position = word_counts[word_id] + offset_buffer[tid];
        index_buffer[sorted_position] = tid;
    }
}
"""
    code %= _configure_template_vals_sort(block_size, kernel_id, in_dtype)
    module = SourceModule(code, options=["--use_fast_math"])

    function_name = "sort_inputs" + native_str(kernel_id)
    kernel = module.get_function(function_name)
    # inputs, index_buffer, offset_buffer, word_counts, vocab_size,
    # input_length, followed by any flex statistics arguments
    kernel.prepare("PPPPII" + flex_sig_sort(in_dtype.str[1]))

    kernel.name = lut_sort_kernel_name
    return kernel
# Example #4
def _transpose_magic64(d):
    """
    Return ``(magic, shift)`` for dividing by the constant ``d`` with the
    ``div64`` device function emitted by ``_get_copy_transpose_kernel``.

    For every ``0 <= n < 2**31``:
      * if ``magic == 1`` (``d`` is a power of two): ``n // d == n >> shift``
      * otherwise: ``n // d == (n * magic) >> (32 + shift)``

    This follows the ``magicgu`` construction from Hacker's Delight. The
    search is restricted to ``p >= 32`` so that ``shift`` is never negative.
    ``magic`` always fits in 32 unsigned bits.
    """
    if d < 1:
        raise ValueError("divisor must be positive, got %r" % (d,))
    if d & (d - 1) == 0:
        # power of two (including 1): a plain right shift
        return 1, d.bit_length() - 1
    nmax = 0x7fffffff  # compound indices come from blockIdx.x, so < 2**31
    nc = ((nmax + 1) // d) * d - 1
    for p in range(32, 65):
        two_p = 1 << p
        err = d - 1 - (two_p - 1) % d  # == ceil(2**p / d) * d - 2**p
        if two_p > nc * err:
            return (two_p + err) // d, p - 32
    raise ValueError("no magic number found for divisor %d" % d)


def _get_copy_transpose_kernel(dtype, shape, axes=None):
    """
    Build a CUDA kernel that copies a strided tensor while permuting its axes.

    Arguments:
        dtype (np.dtype): element type of the tensors.
        shape (tuple): shape of the source tensor.
        axes (sequence, optional): permutation of ``range(len(shape))``
            listing, in output order, the source axis of each output axis.
            Defaults to the reversed axis order (as ``numpy.transpose``).

    Returns:
        The compiled ``copy_transpose`` kernel, prepared with its argument
        signature. ``grid`` and ``block`` are set. ``args`` holds the
        non-stride scalar arguments: the bounds-check dims and the
        magic-division constants. The source and destination strides are the
        trailing kernel arguments and are not part of ``args``.
        1-D shapes are delegated to ``_get_oned_copy_kernel``.
    """
    if len(shape) == 1:
        return _get_oned_copy_kernel(dtype, shape)

    src = list(range(len(shape)))
    # list(None) used to raise TypeError; default to the numpy convention.
    dst = list(axes) if axes is not None else src[::-1]

    src_dim = src[-1]
    dst_dim = dst[-1]

    # If the inner dim is the same for both, no need for shared memory tile
    # Then map the outer source dim to the threadIdx.y values
    if src_dim == dst_dim:
        dst_dim = src[0]
        shared_tile = False
    else:
        shared_tile = True

    src_offset = []
    dst_offset = []
    params = []
    values = []
    magic = ""

    # add dims for bounds checking
    for dim in (src_dim, dst_dim):
        params.append("int dim_%s" % dim)
        values.append(shape[dim])

    # collapse src and dst shape by 32
    grid_shape = list(shape)
    grid_shape[src_dim] = _ceil_div(shape[src_dim], 32)
    grid_shape[dst_dim] = _ceil_div(shape[dst_dim], 32)

    # get a src list without dst dim
    src2 = [s for s in src if s != dst_dim]

    # get the name of the first compound index
    blkx_name = compound_idx = "".join(native_str(x) for x in src2)

    # generate the magic number math to extract all indices from blockIdx.x
    while len(src2) > 1:

        idx1 = src2[0]
        del src2[0]
        idx2 = "".join(native_str(i) for i in src2)
        div = reduce(mul, (grid_shape[i] for i in src2), 1)

        # values must line up one-to-one with these params
        params.extend(p % idx2 for p in ("int magic_%s", "int shift_%s", "int div_%s"))
        values.extend(_transpose_magic64(div))
        values.append(div)

        magic += r"""
    int idx_{1} = div64(idx_{0}, magic_{2}, shift_{2});
    int idx_{2} = idx_{0} - idx_{1}*div_{2};
""".format(compound_idx, idx1, idx2)

        compound_idx = idx2

    # Add params for src strides and generate src offset
    # The param values will be added externally
    for s in src:
        params.append("long long src_str_%d" % s)
        src_offset.append("src_str_%d*idx_%d" % (s, s))

    # Add params for dst strides and generate dst offset
    for d in dst:
        params.append("long long dst_str_%d" % d)
        dst_offset.append("dst_str_%d*idx_%d" % (d, d))

    num_strides = len(src) + len(dst)

    # Division by a runtime constant using the (magic, shift) pairs computed
    # by _transpose_magic64.
    div64_source = r"""
__device__ __forceinline__ int div64(int value, int magic, int shift)
{
    // magic == 1 marks a power-of-two divisor: the quotient is a plain shift.
    // Otherwise take the high 32 bits of the 64-bit product, then shift.
    unsigned int hi = (magic == 1)
        ? (unsigned int)value
        : __umulhi((unsigned int)value, (unsigned int)magic);
    return (int)(hi >> shift);
}
"""

    if shared_tile:
        copy_transpose = r"""
%(common)s
__global__ void copy_transpose(%(type)s* out, const %(type)s* in, %(params)s)
{
    // 32x32 tile; the extra column keeps the transposed (column-wise) reads
    // free of shared-memory bank conflicts.
    __shared__ %(type)s tile[32][33];

    int tid_x = threadIdx.x;
    int tid_y = threadIdx.y;
    int idx_%(blk)s = blockIdx.x;
    int idx_%(dst)s = blockIdx.y;

    %(magic)s

    idx_%(src)s = (idx_%(src)s << 5) + tid_x;
    idx_%(dst)s = (idx_%(dst)s << 5) + tid_y;

    const %(type)s* in00 = in   + %(src_offset)s;
    const %(type)s* in08 = in00 + src_str_%(dst)s*8;
    const %(type)s* in16 = in08 + src_str_%(dst)s*8;
    const %(type)s* in24 = in16 + src_str_%(dst)s*8;

    bool b%(src)s = idx_%(src)s < dim_%(src)s;

    if (idx_%(dst)s +  0 < dim_%(dst)s && b%(src)s) tile[tid_y +  0][tid_x] = *in00;
    if (idx_%(dst)s +  8 < dim_%(dst)s && b%(src)s) tile[tid_y +  8][tid_x] = *in08;
    if (idx_%(dst)s + 16 < dim_%(dst)s && b%(src)s) tile[tid_y + 16][tid_x] = *in16;
    if (idx_%(dst)s + 24 < dim_%(dst)s && b%(src)s) tile[tid_y + 24][tid_x] = *in24;

    // the tile is read transposed, i.e. elements written by other threads
    __syncthreads();

    %(type)s val00 = tile[tid_x][tid_y +  0];
    %(type)s val08 = tile[tid_x][tid_y +  8];
    %(type)s val16 = tile[tid_x][tid_y + 16];
    %(type)s val24 = tile[tid_x][tid_y + 24];

    // swap the roles of tid_x / tid_y for the output coordinates
    idx_%(src)s += tid_y - tid_x;
    idx_%(dst)s += tid_x - tid_y;

    bool b%(dst)s = idx_%(dst)s < dim_%(dst)s;

    %(type)s* out00 = out   + %(dst_offset)s;
    %(type)s* out08 = out00 + dst_str_%(src)s*8;
    %(type)s* out16 = out08 + dst_str_%(src)s*8;
    %(type)s* out24 = out16 + dst_str_%(src)s*8;

    if (idx_%(src)s +  0 < dim_%(src)s && b%(dst)s) *out00 = val00;
    if (idx_%(src)s +  8 < dim_%(src)s && b%(dst)s) *out08 = val08;
    if (idx_%(src)s + 16 < dim_%(src)s && b%(dst)s) *out16 = val16;
    if (idx_%(src)s + 24 < dim_%(src)s && b%(dst)s) *out24 = val24;
}
"""
    else:
        copy_transpose = r"""
%(common)s
__global__ void copy_transpose(%(type)s* out, const %(type)s* in, %(params)s)
{
    int tid_x = threadIdx.x;
    int tid_y = threadIdx.y;
    int idx_%(blk)s = blockIdx.x;
    int idx_%(dst)s = blockIdx.y;

    %(magic)s

    idx_%(src)s = (idx_%(src)s << 5) + tid_x;
    idx_%(dst)s = (idx_%(dst)s << 5) + tid_y;

    bool b%(src)s    = idx_%(src)s      < dim_%(src)s;
    bool b%(dst)s_00 = idx_%(dst)s +  0 < dim_%(dst)s && b%(src)s;
    bool b%(dst)s_08 = idx_%(dst)s +  8 < dim_%(dst)s && b%(src)s;
    bool b%(dst)s_16 = idx_%(dst)s + 16 < dim_%(dst)s && b%(src)s;
    bool b%(dst)s_24 = idx_%(dst)s + 24 < dim_%(dst)s && b%(src)s;

    %(type)s val00 = 0;
    %(type)s val08 = 0;
    %(type)s val16 = 0;
    %(type)s val24 = 0;

    const %(type)s* in00 = in   + %(src_offset)s;
    const %(type)s* in08 = in00 + src_str_%(dst)s*8;
    const %(type)s* in16 = in08 + src_str_%(dst)s*8;
    const %(type)s* in24 = in16 + src_str_%(dst)s*8;

    if (b%(dst)s_00) val00 = *in00;
    if (b%(dst)s_08) val08 = *in08;
    if (b%(dst)s_16) val16 = *in16;
    if (b%(dst)s_24) val24 = *in24;

    %(type)s* out00 = out   + %(dst_offset)s;
    %(type)s* out08 = out00 + dst_str_%(dst)s*8;
    %(type)s* out16 = out08 + dst_str_%(dst)s*8;
    %(type)s* out24 = out16 + dst_str_%(dst)s*8;

    if (b%(dst)s_00) *out00 = val00;
    if (b%(dst)s_08) *out08 = val08;
    if (b%(dst)s_16) *out16 = val16;
    if (b%(dst)s_24) *out24 = val24;
}
"""
    code = copy_transpose % dict(
        common=div64_source,
        type=_get_register_type(dtype, memory=True),
        params=", ".join(params),
        blk=blkx_name,
        src=src_dim,
        dst=dst_dim,
        magic=magic,
        src_offset=" + ".join(src_offset),
        dst_offset=" + ".join(dst_offset),
    )

    module = SourceModule(code)
    kernel = module.get_function("copy_transpose")
    # two pointers, then one 32-bit int per entry of `values`, then strides
    kernel.prepare("PP" + ("I" * (len(params) - num_strides)) + "q" * num_strides)

    grid_x = grid_shape[src_dim]
    grid_y = grid_shape[dst_dim]
    for s in src:
        if s not in (src_dim, dst_dim):
            grid_x *= grid_shape[s]

    kernel.grid = (grid_x, grid_y, 1)
    kernel.block = (32, 8, 1)
    kernel.args = tuple(values)

    return kernel