def _get_oned_copy_kernel(dtype, shape):
    """
    Builds the strided element-wise copy kernel used for 1D tensors.

    Each CUDA thread computes ``idx = blockIdx.x * 32 + threadIdx.x`` and, when
    ``idx < dim``, copies ``in[src_str * idx]`` to ``out[dst_str * idx]``.

    Arguments:
        dtype: Element type; translated to a C type name with
            _get_register_type(dtype, memory=True).
        shape: Tensor shape; only shape[0] is read.

    Returns:
        The prepared "copy_oned" kernel with its ``grid``, ``block`` and
        ``args`` attributes set. ``args`` only holds the ``dim`` value; the
        two stride values are presumably appended by the caller at launch
        time (TODO confirm against the call site).
    """
    template = r"""
__global__ void copy_oned(%(type)s* out, const %(type)s* in, int dim,
                          long long src_str, long long dst_str)
{
    int tid_x = threadIdx.x;
    int idx = blockIdx.x;

    idx = (idx << 5) + tid_x;

    const %(type)s* in0 = in + (src_str * idx);
    %(type)s* out0 = out + (dst_str * idx);

    if(idx < dim)
        *out0 = *in0;
}
"""
    source = template % {"type": _get_register_type(dtype, memory=True)}

    kernel = SourceModule(source).get_function("copy_oned")
    # out pointer, in pointer, dim, src stride, dst stride
    kernel.prepare("PPIqq")

    length = shape[0]
    # One thread per element, 32 threads per block (matches the << 5 above).
    kernel.grid = (_ceil_div(length, 32), 1, 1)
    kernel.block = (32, 1, 1)
    kernel.args = (length,)
    return kernel
def _get_lut_bprop_kernel(dtype, in_dtype, deterministic=False):
    """
    Builds the bprop kernel for lookup table layers based on templated code.
    If the deterministic version is requested, an index buffer must be passed
    as an argument. This index buffer re-orders items in the input tensor
    so that word_ids are sorted. This is required since we need to be sure that
    each thread only updates weights for one word id.

    Arguments:
        dtype (np.dtype): The data which the kernel will operate on.
        in_dtype: Type of the ``inputs`` (word id) buffer. It is handed to
            _configure_template_vals_bprop to obtain the C type name, and
            ``in_dtype.str[1]`` selects the extra argument signature in the
            deterministic kernel.
        deterministic (boolean): Builds the deterministic kernel when this is
            set to True.

    Returns:
        The prepared "lut_bprop" kernel, with its ``name`` attribute set to
        lut_bprop_kernel_name.
    """
    if not deterministic:
        code = r"""
__global__ void lut_bprop(
    %(in_dtype)s* inputs, %(type)s* dW, %(type)s* errors, const int nin,
    const int embedding_dim, const int vocab_size, const int pad_idx)
{
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    int word_id = inputs[bid];
    int error_row = bid * embedding_dim;
    int output_row = word_id * embedding_dim;

    if(word_id != pad_idx)
    {
        for(int i = tid; i < embedding_dim; i += blockDim.x)
        {
            atomicAdd(&dW[output_row + i], errors[error_row + i]);
        }
    }
}
"""

        # The template references both %(type)s and %(in_dtype)s; formatting
        # it with only "type" raised KeyError('in_dtype'). Take the input type
        # name from the same mapping the deterministic template is formatted
        # with (it must define "in_dtype"), and keep "type" exactly as before.
        # NOTE(review): this kernel reads inputs[bid] without applying
        # %(compute_input)s -- confirm that is intended for every in_dtype.
        template_vals = dict(_configure_template_vals_bprop(in_dtype, dtype))
        template_vals["type"] = _get_register_type(dtype)
        code = code % template_vals

        module = SourceModule(code, options=["--use_fast_math"])
        kernel = module.get_function("lut_bprop")
        # inputs, dW, errors, nin, embedding_dim, vocab_size, pad_idx
        kernel.prepare("PPPIIIi")
    else:
        code = r"""
%(common)s

__global__ void lut_bprop(
    %(in_dtype)s* inputs, int* index_buffer, %(type)s* dW, %(type)s* errors,
    const int nin, const int embedding_dim, const int vocab_size,
    const int pad_idx %(stats_args)s)
{
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    int index_position = bid;
    int index = index_buffer[index_position];
    int word_id = inputs[index] %(compute_input)s;
    int intermediate_max = 0;

    if((bid == 0 || word_id != (inputs[index_buffer[bid - 1]] %(compute_input)s)) && word_id != pad_idx)
    {
        int output_row = word_id * embedding_dim;

        do {
            int error_row = index * embedding_dim;

            for(int i = tid; i < embedding_dim; i += blockDim.x)
            {
                %(compute_dW_code)s
            }
            index_position++;
            if(index_position == gridDim.x)
            {
                break;
            }
            index = index_buffer[index_position];
        } while((inputs[index] %(compute_input)s) == word_id);
    }
    %(atomic_max)s
}
"""

        code %= _configure_template_vals_bprop(in_dtype, dtype)

        module = SourceModule(code, options=["--use_fast_math"])
        kernel = module.get_function("lut_bprop")
        # Same as above plus the index buffer pointer and any extra
        # arguments described by flex_sig_bprop.
        kernel.prepare("PPPPIIIi" + flex_sig_bprop(in_dtype.str[1]))

    kernel.name = lut_bprop_kernel_name
    return kernel
def _get_sorting_kernel(kernel_id, block_size, in_dtype):
    """
    Builds one of the kernels used for sorting lookup table inputs.

    Sorting is a bucket sort in which every word_id is a bucket, and it is
    split over several kernels. First the size of each bucket (the number of
    occurrences of each word_id) is counted with an atomic add, whose return
    value is also kept as each item's offset inside its bucket. Then a prefix
    sum over the bucket sizes gives the position where each bucket starts in
    the output buffer. Finally every thread writes its own index to the
    sorted position given by its bucket start plus its offset in the bucket.

    Arguments:
        kernel_id (Integer): Which step to build the kernel for [0, 4]
        block_size (Integer): Number of threads per block for the prefix sum
            kernels.
        in_dtype: Type of the ``inputs`` buffer; forwarded to
            _configure_template_vals_sort, and ``in_dtype.str[1]`` selects
            the extra argument signature via flex_sig_sort.

    Returns:
        The prepared "sort_inputs<kernel_id>" kernel, with its ``name``
        attribute set to lut_sort_kernel_name.
    """
    template = r"""
#define THREADS %(threads)s
#define STORE_BLOCKSUM %(store_blocksum)s

__global__ void sort_inputs0(
    %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
    const int vocab_size, const int input_length %(stats_args)s)
{
    const int tid = threadIdx.x + (blockDim.x * blockIdx.x);
    int word_id;

    if(tid < input_length)
    {
        word_id = inputs[tid] %(compute_input)s;
        offset_buffer[tid] = atomicAdd(&word_counts[word_id], 1);
    }
}

__device__ void scan(int* buffer, int* blocksum, int global_length)
{
    const int tid = (threadIdx.x << 1) + 1;
    const int gid = ((threadIdx.x + (blockIdx.x * blockDim.x)) << 1) + 1;

    __shared__ int local_counts[THREADS * 2];
    local_counts[tid] = buffer[gid];
    local_counts[tid - 1] = buffer[gid - 1];

    #pragma unroll
    for(int skip = 1; skip <= THREADS; skip <<= 1)
    {
        int mask = (skip << 1) - 1;
        if((tid & mask) == mask)
        {
            local_counts[tid] += local_counts[tid - skip];
        }

        __syncthreads();
    }

    if(tid == (THREADS * 2 - 1))
    {
#if STORE_BLOCKSUM
        blocksum[blockIdx.x] = local_counts[tid];
#endif
        local_counts[tid] = 0;
    }

    #pragma unroll
    for(int skip = THREADS; skip > 0; skip >>= 1)
    {
        int mask = (skip << 1) - 1;
        if((tid & mask) == mask)
        {
            int temp = local_counts[tid - skip];
            local_counts[tid - skip] = local_counts[tid];
            local_counts[tid] += temp;
        }

        __syncthreads();
    }

    if(gid < global_length)
    {
        buffer[gid] = local_counts[tid];
        buffer[gid - 1] = local_counts[tid - 1];
    }
}

__global__ void sort_inputs1(
    %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
    const int vocab_size, const int input_length %(stats_args)s)
{
    scan(word_counts, word_counts + vocab_size, vocab_size);
}

__global__ void sort_inputs2(
    %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
    const int vocab_size, const int input_length %(stats_args)s)
{
    scan(word_counts + vocab_size, 0, blockDim.x);
}

__global__ void sort_inputs3(
    %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
    const int vocab_size, const int input_length %(stats_args)s)
{
    const int gid = (threadIdx.x + (blockIdx.x * blockDim.x)) << 1;

    if(gid < vocab_size)
    {
        word_counts[gid] += word_counts[vocab_size + blockIdx.x];
        word_counts[gid + 1] += word_counts[vocab_size + blockIdx.x];
    }
}

__global__ void sort_inputs4(
    %(in_dtype)s* inputs, int* index_buffer, int* offset_buffer, int* word_counts,
    const int vocab_size, const int input_length %(stats_args)s)
{
    const int tid = threadIdx.x + (blockDim.x * blockIdx.x);
    int word_id;

    if(tid < input_length)
    {
        word_id = inputs[tid] %(compute_input)s;
        int sorted_position = word_counts[word_id] + offset_buffer[tid];
        index_buffer[sorted_position] = tid;
    }
}
"""
    # All five steps live in one translation unit; kernel_id only decides
    # which entry point is looked up (and feeds the template values).
    template_vals = _configure_template_vals_sort(block_size, kernel_id, in_dtype)
    compiled = SourceModule(template % template_vals,
                            options=["--use_fast_math"])

    kernel = compiled.get_function("sort_inputs" + native_str(kernel_id))

    # Four buffer pointers, vocab_size, input_length, then any extra
    # arguments described by flex_sig_sort.
    arg_signature = "PPPPII" + flex_sig_sort(in_dtype.str[1])
    kernel.prepare(arg_signature)
    kernel.name = lut_sort_kernel_name
    return kernel
def _get_copy_transpose_kernel(dtype, shape, axes=None):
    """
    Builds a copy/transpose (dimshuffle) kernel specialized for a shape and
    axis permutation.

    A 1D shape is delegated to _get_oned_copy_kernel. Otherwise CUDA source
    is generated in which each 32x8 thread block handles a 32x32 patch
    (four rows per thread, 8 apart). When the innermost source and
    destination dims differ, the patch is staged through a shared memory
    tile so it can be read along one dim and written along the other; when
    they are the same, values are copied directly and the outermost source
    dim is mapped to threadIdx.y instead.

    Arguments:
        dtype: Element type; translated to a C type name with
            _get_register_type(dtype, memory=True).
        shape: Shape of the source tensor.
        axes: Order of the source dims in the destination. None reverses
            the dims, following the numpy.transpose convention.

    Returns:
        The prepared "copy_transpose" kernel with its ``grid``, ``block`` and
        ``args`` attributes set. ``args`` holds the two bounds-check dims and
        the magic/shift/div triples; the stride values are added externally.
    """
    if len(shape) == 1:
        return _get_oned_copy_kernel(dtype, shape)

    src = list(range(len(shape)))
    # list(None) would raise TypeError, so give the declared default a
    # meaning: reverse the dims like numpy.transpose does.
    dst = src[::-1] if axes is None else list(axes)

    src_dim = src[-1]
    dst_dim = dst[-1]

    # If the inner dim is the same for both, no need for shared memory tile
    # Then map the outer source dim to the threadIdx.y values
    if src_dim == dst_dim:
        dst_dim = src[0]
        shared_tile = False
    else:
        shared_tile = True

    src_offset = []
    dst_offset = []
    params = []
    values = []
    magic = ""

    # add dims for bounds checking
    for dim in (src_dim, dst_dim):
        params.append("int dim_%s" % dim)
        values.append(shape[dim])

    # collapse src and dst shape by 32
    grid_shape = list(shape)
    grid_shape[src_dim] = _ceil_div(shape[src_dim], 32)
    grid_shape[dst_dim] = _ceil_div(shape[dst_dim], 32)

    # get a src list without dst dim
    src2 = [s for s in src if s != dst_dim]

    # get the name of the first compound index
    blkx_name = compound_idx = "".join(native_str(x) for x in src2)

    # generate the magic number math to extract all indices
    # (blockIdx.x carries every dim in src2 flattened together; each pass
    # peels off the leading dim with a division by the remaining extent)
    while len(src2) > 1:
        idx1 = src2[0]
        del src2[0]
        idx2 = "".join(native_str(i) for i in src2)
        div = reduce(mul, (grid_shape[i] for i in src2), 1)

        params.extend(p % idx2 for p in ("int magic_%s", "int shift_%s", "int div_%s"))
        values.extend(_magic64(div))
        values.append(div)

        magic += r"""
    int idx_{1} = div64(idx_{0}, magic_{2}, shift_{2});
    int idx_{2} = idx_{0} - idx_{1}*div_{2};
""".format(compound_idx, idx1, idx2)

        compound_idx = idx2

    # Add params for src strides and generate src offset
    # The param values will be added externally
    for s in src:
        params.append("long long src_str_%d" % s)
        src_offset.append("src_str_%d*idx_%d" % (s, s))

    # Add params for dst strides and generate dst offset
    for d in dst:
        params.append("long long dst_str_%d" % d)
        dst_offset.append("dst_str_%d*idx_%d" % (d, d))

    num_strides = len(src) + len(dst)

    if shared_tile:
        copy_transpose = r"""
%(common)s

__global__ void copy_transpose(%(type)s* out, const %(type)s* in, %(params)s)
{
    __shared__ %(type)s tile[32][33];

    int tid_x = threadIdx.x;
    int tid_y = threadIdx.y;
    int idx_%(blk)s = blockIdx.x;
    int idx_%(dst)s = blockIdx.y;

    %(magic)s

    idx_%(src)s = (idx_%(src)s << 5) + tid_x;
    idx_%(dst)s = (idx_%(dst)s << 5) + tid_y;

    const %(type)s* in00 = in + %(src_offset)s;
    const %(type)s* in08 = in00 + src_str_%(dst)s*8;
    const %(type)s* in16 = in08 + src_str_%(dst)s*8;
    const %(type)s* in24 = in16 + src_str_%(dst)s*8;

    bool b%(src)s = idx_%(src)s < dim_%(src)s;

    if (idx_%(dst)s + 0 < dim_%(dst)s && b%(src)s) tile[tid_y + 0][tid_x] = *in00;
    if (idx_%(dst)s + 8 < dim_%(dst)s && b%(src)s) tile[tid_y + 8][tid_x] = *in08;
    if (idx_%(dst)s + 16 < dim_%(dst)s && b%(src)s) tile[tid_y + 16][tid_x] = *in16;
    if (idx_%(dst)s + 24 < dim_%(dst)s && b%(src)s) tile[tid_y + 24][tid_x] = *in24;

    __syncthreads();

    %(type)s val00 = tile[tid_x][tid_y + 0];
    %(type)s val08 = tile[tid_x][tid_y + 8];
    %(type)s val16 = tile[tid_x][tid_y + 16];
    %(type)s val24 = tile[tid_x][tid_y + 24];

    idx_%(src)s += tid_y - tid_x;
    idx_%(dst)s += tid_x - tid_y;

    bool b%(dst)s = idx_%(dst)s < dim_%(dst)s;

    %(type)s* out00 = out + %(dst_offset)s;
    %(type)s* out08 = out00 + dst_str_%(src)s*8;
    %(type)s* out16 = out08 + dst_str_%(src)s*8;
    %(type)s* out24 = out16 + dst_str_%(src)s*8;

    if (idx_%(src)s + 0 < dim_%(src)s && b%(dst)s) *out00 = val00;
    if (idx_%(src)s + 8 < dim_%(src)s && b%(dst)s) *out08 = val08;
    if (idx_%(src)s + 16 < dim_%(src)s && b%(dst)s) *out16 = val16;
    if (idx_%(src)s + 24 < dim_%(src)s && b%(dst)s) *out24 = val24;
}
"""
    else:
        copy_transpose = r"""
%(common)s

__global__ void copy_transpose(%(type)s* out, const %(type)s* in, %(params)s)
{
    int tid_x = threadIdx.x;
    int tid_y = threadIdx.y;
    int idx_%(blk)s = blockIdx.x;
    int idx_%(dst)s = blockIdx.y;

    %(magic)s

    idx_%(src)s = (idx_%(src)s << 5) + tid_x;
    idx_%(dst)s = (idx_%(dst)s << 5) + tid_y;

    bool b%(src)s = idx_%(src)s < dim_%(src)s;
    bool b%(dst)s_00 = idx_%(dst)s + 0 < dim_%(dst)s && b%(src)s;
    bool b%(dst)s_08 = idx_%(dst)s + 8 < dim_%(dst)s && b%(src)s;
    bool b%(dst)s_16 = idx_%(dst)s + 16 < dim_%(dst)s && b%(src)s;
    bool b%(dst)s_24 = idx_%(dst)s + 24 < dim_%(dst)s && b%(src)s;

    %(type)s val00 = 0;
    %(type)s val08 = 0;
    %(type)s val16 = 0;
    %(type)s val24 = 0;

    const %(type)s* in00 = in + %(src_offset)s;
    const %(type)s* in08 = in00 + src_str_%(dst)s*8;
    const %(type)s* in16 = in08 + src_str_%(dst)s*8;
    const %(type)s* in24 = in16 + src_str_%(dst)s*8;

    if (b%(dst)s_00) val00 = *in00;
    if (b%(dst)s_08) val08 = *in08;
    if (b%(dst)s_16) val16 = *in16;
    if (b%(dst)s_24) val24 = *in24;

    %(type)s* out00 = out + %(dst_offset)s;
    %(type)s* out08 = out00 + dst_str_%(dst)s*8;
    %(type)s* out16 = out08 + dst_str_%(dst)s*8;
    %(type)s* out24 = out16 + dst_str_%(dst)s*8;

    if (b%(dst)s_00) *out00 = val00;
    if (b%(dst)s_08) *out08 = val08;
    if (b%(dst)s_16) *out16 = val16;
    if (b%(dst)s_24) *out24 = val24;
}
"""

    code = copy_transpose % dict(
        common=_div64,
        type=_get_register_type(dtype, memory=True),
        params=", ".join(params),
        blk=blkx_name,
        src=src_dim,
        dst=dst_dim,
        magic=magic,
        src_offset=" + ".join(src_offset),
        dst_offset=" + ".join(dst_offset)
    )

    module = SourceModule(code)
    kernel = module.get_function("copy_transpose")
    # Two pointers, then every non-stride param as a 32-bit value, then one
    # 64-bit value per src/dst stride.
    kernel.prepare("PP" + ("I" * (len(params) - num_strides)) + "q" * num_strides)

    # grid.y walks the dst dim; grid.x carries the src dim and every
    # remaining dim flattened together (unpacked by the magic math above).
    grid_x = grid_shape[src_dim]
    grid_y = grid_shape[dst_dim]
    for s in src:
        if s not in (src_dim, dst_dim):
            grid_x *= grid_shape[s]

    kernel.grid = (grid_x, grid_y, 1)
    kernel.block = (32, 8, 1)
    kernel.args = tuple(values)

    return kernel