def _get_compound_kernel(type_args, compute_capability): """ generate compound kernel for the optree from type_args """ # from the stack, rebuild a mutable tree tree = _build_tree(type_args) # _print_tree(tree) # exit() # split all reductions and post reduction scalar operations out of the tree # sub-trees are converted to stacks and pushed onto stages list stages = _split_stages(tree) # _print_tree(tree) # exit() # set the final stage type to type of output (scalar or elementwise) last_stage = "red_out" if tree[1] == 1 else "ew_out" # convert the remainder of tree to stack stages.append((last_stage, _post_order(tree))) # for stage, stage_data in enumerate(stages): # print stage_data[0], stage # for s in stage_data[1]: print s # print # exit() stack = list() placeholders = list() stage_out_reg = dict() arg_dict = dict() array_ids = set() fp16In = False rand_init = False rand_func = False threads = type_args[-1][3] template = _ew_template template_vals = { "threads": threads, "name": _get_kernel_name(), "common": list(), "inits": list(), "finish": list(), } for stage, stage_data in enumerate(stages): stage_type, stage_stack = stage_data new_placeholders = list() # build out the template as we process stages if stage_type == "reduction": new_placeholders.append("loads%d" % stage) new_placeholders.append("ops%d" % stage) new_placeholders.append("shfl_red%d" % stage) template += _stage_template["loop"].format(stage) if threads > 32: new_placeholders.append("var_red%d" % stage) new_placeholders.append("share1_red%d" % stage) new_placeholders.append("share2_red%d" % stage) template += _stage_template["red"].format(stage) else: template += _stage_template["red32"].format(stage) elif stage_type == "scalar": new_placeholders.append("ops%d" % stage) template += _stage_template["red_ops"].format(stage) elif stage_type == "red_out": new_placeholders.append("ops%d" % stage) template += _stage_template["red_out"].format(stage) else: # ew_out new_placeholders.append("loads%d" % stage) new_placeholders.append("ops%d" % stage) template += _stage_template["loop"].format(stage) for key in new_placeholders: template_vals[key] = [] placeholders.extend(new_placeholders) for arg_i, arg in enumerate(stage_stack): arg_type, arg_id = arg[0:2] # Array operands if arg_type is ng.GPUTensor: dtype, take_axis = arg[2:4] is_out_tensor = True if stage == len( stages) - 1 and arg_i == 0 else False # first arg is output array, don't put on stack if is_out_tensor: out_dtype = dtype out_take = take_axis else: stack.append("a%d" % arg_id) # 0: arg_id, 1: stage, 2: type, 3: cvt ew_dtype = _ew_types[dtype] fmt = (arg_id, stage, ew_dtype["type"], ew_dtype["cvt"]) # First time we see a tensor initialize everything if arg_id not in array_ids: array_ids.add(arg_id) array_ids.add((arg_id, stage)) sig = "Pii" if take_axis > 0: sig += "P" # output tensor if is_out_tensor: ew_out = _ew_strings["out%d" % take_axis] arguments = ew_out["arguments"].format(*fmt) template_vals["inits"].append( ew_out["inits"].format(*fmt)) # input tensors else: ew_in = _ew_strings["in%d" % take_axis] loads = "loads%d" % stage arguments = ew_in["arguments"].format(*fmt) template_vals["inits"].append( ew_in["inits"].format(*fmt)) template_vals[loads].append( ew_in["loads"].format(*fmt)) if dtype == 'f2' and not fp16In: template_vals["common"].append(_common_fp16_to_fp32) fp16In = True arg_dict[arg] = (sig, arguments) # Subsequent times we see a tensor just initialize inits and # loads elif (arg_id, stage) not in array_ids: array_ids.add((arg_id, stage)) ew_in 
= _ew_strings["in%d" % take_axis] loads = "loads%d" % stage template_vals["inits"].append(ew_in["inits"].format(*fmt)) template_vals[loads].append(ew_in["loads"].format(*fmt)) # Constant operands elif arg_type is float: stack.append("c%d" % arg_id) if arg not in arg_dict: arg_dict[arg] = ( "f", _ew_strings["const"]["arguments"].format(arg_id)) # Operations (arg_type = op_name) else: if arg_type == "assign": ops = "ops%d" % stage # loop end condition for last stage sig = "i" arguments = ["const int n%d" % stage] # rounding mode if arg[2]: mode = "random" sig += "i" arguments.append("const int mantissa_bits") if not rand_init: rand_init = _init_rand(template_vals) template_vals["inits"].append(_init_rand_round_func) else: mode = "nearest" arg_dict[arg] = (sig, ", ".join(arguments)) out_val = stack.pop() # if the last stack value came from an argmax/min just do # implicit type conversion if out_val[0] == "i" and out_dtype[0] in "iu": ew_round = None else: ew_round = _ew_strings["round"][ mode].get(out_dtype, None) ew_common = _common_round[mode].get(out_dtype, None) if ew_common: template_vals["common"].append(ew_common) if ew_round: round_val = "r%d" % arg_id template_vals[ops].append( ew_round.format(round_val, out_val)) else: round_val = out_val template_vals[ops].append( _ew_strings["out%d" % out_take]["output"].format(round_val)) elif arg in stage_out_reg: stack.append(stage_out_reg[arg]) elif arg_type in _float_ops: if len(template_vals["name"]) < 16: template_vals["name"].append(arg_type) ops = "ops%d" % stage (num_ops, op_code) = _float_ops[arg_type] if arg_type == "rand": if not rand_init: rand_init = _init_rand(template_vals) if not rand_func: template_vals["common"].append(_common_frand) rand_func = True op_list = ["r%d" % arg_id] # build the operands from the stack for i in range(num_ops): op_list.append(stack.pop()) if arg_type == "onehot": hot_axis = arg[2] test_val = "i" if hot_axis else "bid" ew_in = _ew_strings[arg_type + native_str(hot_axis)] loads = "loads%d" % stage template_vals["inits"].append( ew_in["inits"].format(arg_id)) template_vals[loads].append( ew_in["loads"].format(arg_id)) op_list.append("onehot%d" % arg_id) op_list.append(test_val) arg_dict[arg] = ( "P", ew_in["arguments"].format(arg_id)) template_vals[ops].append(op_code.format(*op_list)) # if this is the last op on the current stack, store its register stage # in the stage output dict if arg_i == len(stage_stack) - 1: stage_out_reg[arg] = op_list[0] # otherwise push the reg onto the stack as normal else: stack.append(op_list[0]) elif arg_type in _reduction_ops: if len(template_vals["name"]) < 16: template_vals["name"].append(arg_type) # loop end condition for current stage # add regardless of duplicate reduction stage arg_dict[arg] = ("i", "const int n%d" % stage) # avoid float conversion for argmax/min reg = "i" if "arg" == arg_type[0:3] else "r" ops = "ops%d" % stage shfl_red = "shfl_red%d" % stage red_arg = "%s%d" % (reg, arg_id) red_strings = _reduction_ops[arg_type] stack_arg = stack.pop() template_vals["inits"].append( red_strings["inits"].format(red_arg)) template_vals[ops].append( red_strings["ops"].format(red_arg, stack_arg)) template_vals[shfl_red].append( red_strings["shfl_red"].format(red_arg)) if threads > 32: var_red = "var_red%d" % stage shr1_red = "share1_red%d" % stage shr2_red = "share2_red%d" % stage template_vals[var_red].append(red_arg) template_vals[shr1_red].append( red_strings["share1_red"].format(red_arg)) template_vals[shr2_red].append( red_strings["share2_red"].format(red_arg)) # 
reduction ops are always the last on the stack # just store the register state in the stage output dict stage_out_reg[arg] = red_arg else: raise ValueError("Bad op type.") if (compute_capability[0] == 3 and compute_capability[1] < 5) or compute_capability[0] < 3: template_vals["common"].append(_common_kepler) template += _fin_template # since we reordered the operations we need to generate the argument list # in the original order sig = "P" arguments = list() unused = 1 for arg in type_args: params = arg_dict.get(arg, False) if params: sig += params[0] arguments.append(params[1]) del arg_dict[arg] # fill in the loop counter for the duplicate reductions that were # removed elif arg[0] in _reduction_ops: sig += "i" arguments.append("const int unused%d" % unused) unused += 1 # convert lists to strings template_vals["name"] = "_".join(template_vals["name"]) template_vals["common"] = "\n".join(template_vals["common"]) template_vals["arguments"] = ",\n ".join(arguments) template_vals["inits"] = "\n ".join(template_vals["inits"]) template_vals["finish"] = "\n".join(template_vals["finish"]) # add the dynamic placeholders: loads#, ops#, reduction# for key in placeholders: template_vals[key] = "\n ".join(template_vals[key]) # populate the template code = template % template_vals # debugging: # print "Compiling %s" % template_vals["name"] # f = open("kernel.cu", "w") # f = open("%s.cu" % template_vals["name"], "w") # print >>f, code # f.close() # ,"-G" , keep=False # module = SourceModule(code, options=["--use_fast_math"]) module = SourceModule(code, options=[]) kernel = module.get_function(template_vals["name"]) kernel.name = template_vals["name"] kernel.prepare(sig) return kernel
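# Illustrative sketch (not part of the original module): _get_compound_kernel
# turns the optree into a postfix stack and emits one code snippet per op,
# popping operand registers and pushing the result register. The tiny host-side
# evaluator below shows the same stack discipline; the two-operand op table here
# is a simplified assumption, whereas the generator draws its snippets from
# _float_ops and _reduction_ops.
def _example_postfix_eval(postfix):
    """Evaluate a postfix stack, e.g. [2.0, 3.0, "add", 4.0, "mul"] -> 20.0."""
    import operator
    ops = {"add": operator.add, "sub": operator.sub, "mul": operator.mul}
    stack = []
    for item in postfix:
        if item in ops:
            rhs = stack.pop()
            lhs = stack.pop()
            stack.append(ops[item](lhs, rhs))
        else:
            stack.append(item)
    return stack.pop()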
def _get_compensated_sum_kernel(dtype, rounding): _compensated_sum = r""" %(common)s __global__ void compensated_sum(unsigned* rand_state, %(type)s* a_sum, %(type)s* a_cmp, const %(type)s* a_add, float cmp_scale, float add_scale, int row_strd, int col_strd, int n, int mantissa_bits) { const int tid = threadIdx.x; const int bid = blockIdx.x; int offset = bid * row_strd + tid * col_strd; int inc = 32 * col_strd; a_sum += offset; a_cmp += offset; a_add += offset; %(inits)s for (int i = tid; i < n; i += 32) { float s32 = %(cvt)s(__ldg((const %(type)s*)a_sum)); float c32 = %(cvt)s(__ldg((const %(type)s*)a_cmp)); float a32 = %(cvt)s(__ldg(a_add)); // Adjust amount to add by previous compensation float y32 = a32 * add_scale - c32 * cmp_scale; // Do the accumulation and truncate to the storage type float rnd_sum = s32 + y32; %(rnd_sum)s // Convert accumulation back to fp32 so we can do more math on it float t32 = %(cvt)s(t16); // recover the low order bits that were lost in the truncation float rnd_cmp = (t32 - s32) - y32; %(rnd_cmp)s *a_sum = t16; *a_cmp = c16; a_sum += inc; a_cmp += inc; a_add += inc; } %(finish)s } """ template_vals = dict() for key in ("common", "inits", "finish"): template_vals[key] = "" if dtype == "f2": template_vals["common"] += _common_fp16_to_fp32 if rounding: template_vals["common"] += _common_urand_gen template_vals["common"] += _common_round["nearest"].get(dtype, "") template_vals["inits"] += _init_rand_func + _init_rand_round_func template_vals["finish"] += _finish_rand_func mode = "random" else: mode = "nearest" template_vals["common"] += _common_round[mode].get(dtype, "") template_vals["type"] = _ew_types[dtype]["type"] template_vals["cvt"] = _ew_types[dtype]["cvt"] no_op = "float {0} = {1};" rnd_sum = _ew_strings["round"][mode].get(dtype, no_op) rnd_cmp = _ew_strings["round"]["nearest"].get(dtype, no_op) template_vals["rnd_sum"] = rnd_sum.format("t16", "rnd_sum") template_vals["rnd_cmp"] = rnd_cmp.format("c16", "rnd_cmp") code = _compensated_sum % template_vals # f = open("compensated_sum.cu", "w") # print >>f, code # f.close() module = SourceModule(code) kernel = module.get_function("compensated_sum") kernel.prepare("PPPPffiiii") return kernel
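# Illustrative sketch (not part of the original module): host-side equivalent of
# the per-element update done by compensated_sum. Accumulation happens in fp32,
# the result is truncated to the storage dtype of a_sum (fp16 is one assumption),
# and the low-order bits lost in that truncation are carried in the compensation
# buffer for the next update.
import numpy as np

def _example_compensated_sum(a_sum, a_cmp, a_add, cmp_scale=1.0, add_scale=1.0):
    s32 = a_sum.astype(np.float32)
    c32 = a_cmp.astype(np.float32)
    a32 = a_add.astype(np.float32)
    y32 = a32 * add_scale - c32 * cmp_scale         # adjust by previous compensation
    t16 = (s32 + y32).astype(a_sum.dtype)           # accumulate, truncate to storage type
    t32 = t16.astype(np.float32)
    c16 = ((t32 - s32) - y32).astype(a_cmp.dtype)   # recover the lost low-order bits
    return t16, c16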
def _get_hist_kernel(dtype_str, nbins, offset): """ Build a kernel to compute a 64 bin histogram. Use templating to generate a customized kernel depending on the input data type. Memoized to avoid compiling the same kernel twice. """ type_str = _ew_types[dtype_str[1:]] from string import Template code = Template(_common_fp16_to_fp32 + r""" #define MAX(a,b) (a > b ? a : b) #define MIN(a,b) (a < b ? a : b) __global__ void kernel_histo ( int* d_hist, const $in_type* a1_in, int strides, int size) { const int tid = threadIdx.x; const int bid = blockIdx.x; __shared__ int s[$nbins]; if(tid < $nbins){ s[tid] = 0; } if(bid == 0 && tid < $nbins){ d_hist[tid] = 0; } for (int i = tid + blockDim.x*bid; i < size; i += strides) { float a1 = $convert_to_float(__ldg(a1_in + i)); float absval = fabs(a1); float logabs = round(log2f(absval)); int bin = MIN($nbins-1, MAX(0, logabs-($offset))); atomicAdd(&s[bin], 1); } __syncthreads(); if(tid < $nbins){ atomicAdd(&d_hist[tid], s[tid]); } } """) module = SourceModule(code.substitute(in_type=type_str['type'], convert_to_float=type_str['cvt'], nbins=nbins, offset=offset), options=[]) kernel = module.get_function("kernel_histo") kernel.prepare("PPII") return kernel
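# Illustrative sketch (not part of the original module): a NumPy version of the
# binning rule used by kernel_histo, handy for checking device results on the
# host. nbins and offset correspond to the template parameters above.
import numpy as np

def _example_histogram(values, nbins, offset):
    with np.errstate(divide="ignore"):
        logabs = np.round(np.log2(np.abs(values.astype(np.float32))))
    bins = np.clip(logabs - offset, 0, nbins - 1).astype(np.int32)
    return np.bincount(bins, minlength=nbins)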
def _get_sorting_kernel(kernel_id, block_size): """ Builds kernels used for sorting inputs. There are several kernels here corresponding to the steps in the algorithm. The algorithm works by determining the sorted position for each input item. This is done with a bucket sort algorithm, where each word_id is a bucket. The first step determines the size of each bucket (number of occurrences of each word_id). Next, a prefix sum is computed over the list of bucket sizes to find where each bucket will be placed in the output buffer. Finally, each thread places its index into the correct sorted position based on the bucket start index (computed from the prefix sum) and that thread's offset into the bucket (which is taken from the output of the atomic add done in the first step). Arguments: kernel_id (Integer): Which step to build the kernel for [0, 4] block_size (Integer): Number of threads per block for the prefix sum kernels. """ code = r""" #define THREADS %(threads)s #define STORE_BLOCKSUM %(store_blocksum)s __global__ void sort_inputs0( int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size, const int input_length) { const int tid = threadIdx.x + (blockDim.x * blockIdx.x); int word_id; if(tid < input_length) { word_id = inputs[tid]; offset_buffer[tid] = atomicAdd(&word_counts[word_id], 1); } } __device__ void scan(int* buffer, int* blocksum, int global_length) { const int tid = (threadIdx.x << 1) + 1; const int gid = ((threadIdx.x + (blockIdx.x * blockDim.x)) << 1) + 1; __shared__ int local_counts[THREADS * 2]; local_counts[tid] = buffer[gid]; local_counts[tid - 1] = buffer[gid - 1]; #pragma unroll for(int skip = 1; skip <= THREADS; skip <<= 1) { int mask = (skip << 1) - 1; if((tid & mask) == mask) { local_counts[tid] += local_counts[tid - skip]; } __syncthreads(); } if(tid == (THREADS * 2 - 1)) { #if STORE_BLOCKSUM blocksum[blockIdx.x] = local_counts[tid]; #endif local_counts[tid] = 0; } #pragma unroll for(int skip = THREADS; skip > 0; skip >>= 1) { int mask = (skip << 1) - 1; if((tid & mask) == mask) { int temp = local_counts[tid - skip]; local_counts[tid - skip] = local_counts[tid]; local_counts[tid] += temp; } __syncthreads(); } if(gid < global_length) { buffer[gid] = local_counts[tid]; buffer[gid - 1] = local_counts[tid - 1]; } } __global__ void sort_inputs1( int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size, const int input_length) { scan(word_counts, word_counts + vocab_size, vocab_size); } __global__ void sort_inputs2( int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size, const int input_length) { scan(word_counts + vocab_size, 0, blockDim.x); } __global__ void sort_inputs3( int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size, const int input_length) { const int gid = (threadIdx.x + (blockIdx.x * blockDim.x)) << 1; if(gid < vocab_size) { word_counts[gid] += word_counts[vocab_size + blockIdx.x]; word_counts[gid + 1] += word_counts[vocab_size + blockIdx.x]; } } __global__ void sort_inputs4( int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size, const int input_length) { const int tid = threadIdx.x + (blockDim.x * blockIdx.x); int word_id; if(tid < input_length) { word_id = inputs[tid]; int sorted_position = word_counts[word_id] + offset_buffer[tid]; index_buffer[sorted_position] = tid; } } """ code = code % { "threads": block_size, "store_blocksum": (1 if kernel_id == 1 else 0) } module =
SourceModule(code, options=["--use_fast_math"]) function_name = "sort_inputs" + native_str(kernel_id) kernel = module.get_function(function_name) kernel.prepare("PPPPII") kernel.name = "sort_inputs" return kernel
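# Illustrative sketch (not part of the original module): the five sort_inputs
# kernels together implement the counting/prefix-sum bucket sort below.
# sort_inputs0 counts occurrences per word_id (the atomicAdd also records each
# item's offset within its bucket), sort_inputs1-3 turn the counts into an
# exclusive prefix sum of bucket start positions, and sort_inputs4 scatters each
# input index into its sorted slot.
def _example_sort_inputs(inputs, vocab_size):
    word_counts = [0] * vocab_size
    offset_buffer = []
    for word_id in inputs:                      # counting pass (sort_inputs0)
        offset_buffer.append(word_counts[word_id])
        word_counts[word_id] += 1
    starts, running = [], 0
    for count in word_counts:                   # exclusive prefix sum (scan kernels)
        starts.append(running)
        running += count
    index_buffer = [0] * len(inputs)
    for i, word_id in enumerate(inputs):        # scatter (sort_inputs4)
        index_buffer[starts[word_id] + offset_buffer[i]] = i
    return index_buffer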
def _get_fprop_roipooling(clss): code = r""" #define FLT_MAX 3.402823466E+38F __global__ void fprop_roipooling(const int nthreads, const int num_rois, const int img_count, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const float* bottom_data, const float* bottom_rois, float* top_data, int* argmax_data, const float spatial_scale) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; \ index < (nthreads); index += blockDim.x * gridDim.x){ // (c, ph, pw, n) is an element in the pooled output int n = index % num_rois; int pw = (index / num_rois) % pooled_width; int ph = (index / num_rois / pooled_width) % pooled_height; int c = index / num_rois / pooled_width / pooled_height; bottom_rois += n * 5; int roi_batch_ind = bottom_rois[0]; int roi_start_w = round(bottom_rois[1] * spatial_scale); int roi_start_h = round(bottom_rois[2] * spatial_scale); int roi_end_w = round(bottom_rois[3] * spatial_scale); int roi_end_h = round(bottom_rois[4] * spatial_scale); // Force malformed ROIs to be 1x1 int roi_width = max(roi_end_w - roi_start_w + 1, 1); int roi_height = max(roi_end_h - roi_start_h + 1, 1); float bin_size_h = static_cast<float>(roi_height) / static_cast<float>(pooled_height); float bin_size_w = static_cast<float>(roi_width) / static_cast<float>(pooled_width); int hstart = static_cast<int>(floor(static_cast<float>(ph) * bin_size_h)); int wstart = static_cast<int>(floor(static_cast<float>(pw) * bin_size_w)); int hend = static_cast<int>(ceil(static_cast<float>(ph + 1) * bin_size_h)); int wend = static_cast<int>(ceil(static_cast<float>(pw + 1) * bin_size_w)); // Add roi offsets and clip to input boundaries hstart = min(max(hstart + roi_start_h, 0), height); hend = min(max(hend + roi_start_h, 0), height); wstart = min(max(wstart + roi_start_w, 0), width); wend = min(max(wend + roi_start_w, 0), width); bool is_empty = (hend <= hstart) || (wend <= wstart); // Define an empty pooling region to be zero float maxval = is_empty ? 0 : -FLT_MAX; // If nothing is pooled, argmax = -1 causes nothing to be backprop'd int maxidx = -1; bottom_data += c * height * width * img_count; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { int bottom_index = h * width * img_count + w * img_count + roi_batch_ind; if (bottom_data[bottom_index] > maxval) { maxval = bottom_data[bottom_index]; maxidx = bottom_index; } } } top_data[index] = maxval; argmax_data[index] = maxidx; // Notice the maxidx (from bottom_index) is relative to the dimension // (h, w, img_count) of the feature map, so max value is HWN } } """ module = SourceModule(code) kernel = module.get_function("fprop_roipooling") sig = "8I 4P 1f" kernel.prepare(sig) return kernel
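# Illustrative sketch (not part of the original module): forward ROI max pooling
# for a single channel of a single image, mirroring the index arithmetic of
# fprop_roipooling. feature_map is assumed to be a 2-D (H, W) float array and
# roi an (x1, y1, x2, y2) box in un-scaled input coordinates; the real kernel
# works on (C, H, W, N) data and records argmax relative to that layout.
import numpy as np

def _example_roipool_fprop(feature_map, roi, pooled_h, pooled_w, spatial_scale=1.0):
    H, W = feature_map.shape
    x1, y1, x2, y2 = [int(round(v * spatial_scale)) for v in roi]
    roi_w = max(x2 - x1 + 1, 1)                 # force malformed ROIs to be 1x1
    roi_h = max(y2 - y1 + 1, 1)
    bin_h, bin_w = roi_h / pooled_h, roi_w / pooled_w
    out = np.zeros((pooled_h, pooled_w), dtype=np.float32)
    argmax = np.full((pooled_h, pooled_w), -1, dtype=np.int32)
    for ph in range(pooled_h):
        for pw in range(pooled_w):
            hs = min(max(int(np.floor(ph * bin_h)) + y1, 0), H)
            he = min(max(int(np.ceil((ph + 1) * bin_h)) + y1, 0), H)
            ws = min(max(int(np.floor(pw * bin_w)) + x1, 0), W)
            we = min(max(int(np.ceil((pw + 1) * bin_w)) + x1, 0), W)
            if he <= hs or we <= ws:
                continue                        # empty bin pools to zero, argmax stays -1
            window = feature_map[hs:he, ws:we]
            r, c = np.unravel_index(np.argmax(window), window.shape)
            out[ph, pw] = window[r, c]
            argmax[ph, pw] = (hs + r) * W + (ws + c)
    return out, argmax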
def _get_bprop_roipooling(clss): code = r""" __global__ void bprop_roipooling(const int nthreads, const int num_rois, const int img_count, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const float* top_diff, const float* bottom_rois, float* bottom_diff, const int* argmax_data, const float spatial_scale) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; \ index < (nthreads); index += blockDim.x * gridDim.x){ // (c, h, w, n) coords in bottom data on feature map int n = index % img_count; int w = (index / img_count) % width; int h = (index / img_count / width) % height; int c = index / img_count/ width / height; float gradient = 0; // Accumulate gradient over all ROIs that pooled this element for (int roi_n = 0; roi_n < num_rois; ++roi_n) { const float* offset_bottom_rois = bottom_rois + roi_n * 5; int roi_batch_ind = offset_bottom_rois[0]; // Skip if ROI's batch index doesn't match n if (n != roi_batch_ind) { continue; } int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); // Skip if ROI doesn't include (h, w) const bool in_roi = (w >= roi_start_w && w <= roi_end_w && h >= roi_start_h && h <= roi_end_h); if (!in_roi) { continue; } int offset = c * pooled_height * pooled_width * num_rois; const float* offset_top_diff = top_diff + offset; const int* offset_argmax_data = argmax_data + offset; // Compute feasible set of pooled units that could have pooled // this bottom unit // Force malformed ROIs to be 1x1 int roi_width = max(roi_end_w - roi_start_w + 1, 1); int roi_height = max(roi_end_h - roi_start_h + 1, 1); float bin_size_h = static_cast<float>(roi_height) / static_cast<float>(pooled_height); float bin_size_w = static_cast<float>(roi_width) / static_cast<float>(pooled_width); int phstart = floor(static_cast<float>(h - roi_start_h) / bin_size_h); int phend = ceil(static_cast<float>(h - roi_start_h + 1) / bin_size_h); int pwstart = floor(static_cast<float>(w - roi_start_w) / bin_size_w); int pwend = ceil(static_cast<float>(w - roi_start_w + 1) / bin_size_w); phstart = min(max(phstart, 0), pooled_height); phend = min(max(phend, 0), pooled_height); pwstart = min(max(pwstart, 0), pooled_width); pwend = min(max(pwend, 0), pooled_width); for (int ph = phstart; ph < phend; ++ph) { for (int pw = pwstart; pw < pwend; ++pw) { int top_index = ph * pooled_width * num_rois + pw * num_rois + roi_n; int bottom_index = h * width * img_count + w * img_count + roi_batch_ind; if (offset_argmax_data[top_index] == bottom_index) { gradient += offset_top_diff[top_index]; } } } } bottom_diff[index] = gradient; } } """ module = SourceModule(code) kernel = module.get_function("bprop_roipooling") sig = "8I 4P 1f" kernel.prepare(sig) return kernel
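# Illustrative sketch (not part of the original module): the backward pass routes
# each pooled gradient to the input element whose index was recorded in argmax
# during fprop. Shapes follow the single-channel sketch above; the kernel itself
# loops over bottom elements and ROIs, which accumulates the same sums.
import numpy as np

def _example_roipool_bprop(top_diff, argmax, H, W):
    bottom_diff = np.zeros(H * W, dtype=np.float32)
    for grad, idx in zip(top_diff.ravel(), argmax.ravel()):
        if idx >= 0:                            # idx == -1 means nothing was pooled
            bottom_diff[idx] += grad
    return bottom_diff.reshape(H, W)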
def _get_bn_fprop_kernel(dtype, threads, compute_capability): if threads > 32: shr_code = "__shared__ float sPartials[THREADS];" red_code = r""" sPartials[tid] = xvar; __syncthreads(); #pragma unroll for (int a = THREADS >> 1; a > 32; a >>= 1) { if ( tid < a ) sPartials[tid] += sPartials[tid + a]; __syncthreads(); } if ( tid < 32 ) { xvar = sPartials[tid] + sPartials[tid + 32]; #pragma unroll for (int i = 16; i > 0; i >>= 1) xvar += __shfl_xor(xvar, i); sPartials[tid] = xvar * rcpN; } __syncthreads(); xvar = sPartials[0]; """ else: shr_code = "" red_code = r""" #pragma unroll for (int i = 16; i > 0; i >>= 1) xvar += __shfl_xor(xvar, i); xvar *= rcpN; """ code = r""" #define THREADS %(threads)s %(common)s %(binary)s __global__ void batchnorm_fprop ( %(type)s* y_out, float* xvar_out, float* gmean_out, float* gvar_out, const %(type)s* x_in, const float* xsum_in, const float* gmean_in, const float* gvar_in, const float* gamma_in, const float* beta_in, const float eps, const float rho, const float accumbeta, const int N, const int relu, bool binary) { %(share)s const int tid = threadIdx.x; const int bid = blockIdx.x; int offset = bid * N; const %(type)s* x_in0 = x_in + offset + tid; const float rcpN = 1.0f/(float)N; float xmean = __ldg(xsum_in + bid) * rcpN; float xvar = 0.0f; for (int i = tid; i < N; i += THREADS) { float x = %(cvt)s(__ldg(x_in0)); x_in0 += THREADS; x -= xmean; if (binary) { xvar += shift_element(x, x, true); } else { xvar += x * x; } } %(red)s float gamma = __ldg(gamma_in + bid); float beta = __ldg(beta_in + bid); if ( tid == 0 ) { float gmean = __ldg(gmean_in + bid); float gvar = __ldg(gvar_in + bid); *(xvar_out + bid) = xvar; *(gmean_out + bid) = gmean * rho + (1.0f - rho) * xmean; *(gvar_out + bid) = gvar * rho + (1.0f - rho) * xvar; } float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps); int start = N - (THREADS*4 - tid); offset += start; x_in += offset; y_out += offset; for (int i = start; i >= -THREADS*3; i -= THREADS*4) { float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in + THREADS*0)) : 0.0f; float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in + THREADS*1)) : 0.0f; float x2 = i >= -THREADS*2 ? 
%(cvt)s(__ldg(x_in + THREADS*2)) : 0.0f; float x3 = %(cvt)s(__ldg(x_in + THREADS*3)); x_in -= THREADS*4; float xhat0 = 0.0f; float xhat1 = 0.0f; float xhat2 = 0.0f; float xhat3 = 0.0f; float y0 = 0.0f; float y1 = 0.0f; float y2 = 0.0f; float y3 = 0.0f; if (binary) { xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true); xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true); xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true); xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true); y0 = shift_element(xhat0, gamma, true) + beta; y1 = shift_element(xhat1, gamma, true) + beta; y2 = shift_element(xhat2, gamma, true) + beta; y3 = shift_element(xhat3, gamma, true) + beta; } else { xhat0 = (x0 - xmean) * xvar_rcp_sqrt; xhat1 = (x1 - xmean) * xvar_rcp_sqrt; xhat2 = (x2 - xmean) * xvar_rcp_sqrt; xhat3 = (x3 - xmean) * xvar_rcp_sqrt; y0 = xhat0 * gamma + beta; y1 = xhat1 * gamma + beta; y2 = xhat2 * gamma + beta; y3 = xhat3 * gamma + beta; } if (relu) { y0 = fmaxf(y0, 0.0f); y1 = fmaxf(y1, 0.0f); y2 = fmaxf(y2, 0.0f); y3 = fmaxf(y3, 0.0f); } %(y0_out)s %(y1_out)s %(y2_out)s %(y3_out)s if (accumbeta == 0.0) { if (i >= -THREADS*0) *(y_out + THREADS*0) = y0_val; if (i >= -THREADS*1) *(y_out + THREADS*1) = y1_val; if (i >= -THREADS*2) *(y_out + THREADS*2) = y2_val; *(y_out + THREADS*3) = y3_val; } else { if (i >= -THREADS*0) *(y_out + THREADS*0) = y_out[THREADS*0] * accumbeta + y0_val; if (i >= -THREADS*1) *(y_out + THREADS*1) = y_out[THREADS*1] * accumbeta + y1_val; if (i >= -THREADS*2) *(y_out + THREADS*2) = y_out[THREADS*2] * accumbeta + y2_val; *(y_out + THREADS*3) = y_out[THREADS*3] * accumbeta + y3_val; } y_out -= THREADS*4; } } """ out_code = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};") common_code = _common_round["nearest"].get(dtype, "") if dtype == "f2": common_code += _common_fp16_to_fp32 if (compute_capability[0] == 3 and compute_capability[1] < 5) or compute_capability[0] < 3: common_code += _common_kepler code = code % { "common" : common_code, "binary" : shift_element(), "share" : shr_code, "red" : red_code, "threads" : threads, "type" : _ew_types[dtype]["type"], "cvt" : _ew_types[dtype]["cvt"], "y0_out" : out_code.format("y0_val", "y0"), "y1_out" : out_code.format("y1_val", "y1"), "y2_out" : out_code.format("y2_val", "y2"), "y3_out" : out_code.format("y3_val", "y3"), } module = SourceModule(code, options=["--use_fast_math"]) kernel = module.get_function("batchnorm_fprop") kernel.prepare("PPPPPPPPPPfffIII") kernel.name = "batchnorm_fprop" return kernel
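# Illustrative sketch (not part of the original module): the per-row math
# computed by batchnorm_fprop, written with NumPy. x is one feature row of
# length N and xsum its precomputed sum (the kernel receives the sum rather
# than the mean); the accumbeta blending of the output buffer is omitted here.
import numpy as np

def _example_bn_fprop(x, xsum, gmean, gvar, gamma, beta, eps, rho, relu=False):
    N = x.shape[0]
    xmean = xsum / N
    xvar = np.mean((x - xmean) ** 2)
    gmean_new = gmean * rho + (1.0 - rho) * xmean   # running mean update
    gvar_new = gvar * rho + (1.0 - rho) * xvar      # running variance update
    xhat = (x - xmean) / np.sqrt(xvar + eps)
    y = xhat * gamma + beta
    if relu:
        y = np.maximum(y, 0.0)
    return y, xvar, gmean_new, gvar_new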
def _get_conv_kernel(dtype, filter_size, bsum, operation, filter_bounds_check=False, debug=False): """ Builds the convolution kernel for a specified filter size. Arguments: dtype (np.dtype): The data type which the kernel will operate on. filter_size (int): Total number of elements per filter (R * S) bsum (boolean): If set to true, kernel will include code to compute batch sum during fprop operation (string): Determines which kernel to build. options follow: 'fprop': Forward propagation of activations. 'bprop': Backward propagation of error. 'update': Computes gradients for filter weights based on error and inputs. filter_bounds_check (boolean): Checks if filter weight is in bounds when K is not a multiple of 32. debug (boolean): When set to true, kernels will be compiled with debug symbols. """ assert operation in ["fprop", "bprop", "update"] if operation == "fprop" or operation == "update": lut_code = r""" if(tid < 32) { int rs = tid; int base_x, base_y; base_x = output_pixel_x * stride_w - padding_w; base_y = output_pixel_y * stride_h - padding_h; unsigned int mask = (1 << tid) - 1; while(rs < FILTER_SIZE) { int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x); int index_x = base_x + filter_x * dilation_w; int index_y = base_y + filter_y * dilation_h; //Check if the index is valid int in_bounds = (index_x >= 0 && index_x < W && index_y >= 0 && index_y < H); unsigned int threads_in_bounds = __ballot(in_bounds); //Store lookup table entry if(in_bounds) { int2 lut_entry; lut_entry.x = ((index_y * W + index_x) * N) >> 2; lut_entry.y = (rs * K) >> 2; int index = lut_size_local + __popc(threads_in_bounds & mask); lookup_table[index] = lut_entry; } lut_size_local += __popc(threads_in_bounds); rs += 32; } } """ elif operation == "bprop": lut_code = r""" if(tid < 32) { int rs = tid; int base_q, base_p; base_q = output_pixel_x - ((S - 1) * dilation_w - padding_w); base_p = output_pixel_y - ((R - 1) * dilation_h - padding_h); unsigned int mask = (1 << tid) - 1; while(rs < FILTER_SIZE) { int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x); int index_q = base_q + filter_x * dilation_w; int index_p = base_p + filter_y * dilation_h; //Check if the index is valid int in_bounds = (((index_q % stride_w) | (index_p % stride_h)) == 0); index_q /= stride_w; index_p /= stride_h; in_bounds = in_bounds && (index_q >= 0 && index_q < W && index_p >= 0 && index_p < H); unsigned int threads_in_bounds = __ballot(in_bounds); //Store lookup table entry if(in_bounds) { int2 lut_entry; lut_entry.x = (((index_p * W) + index_q) * N) >> 2; lut_entry.y = (rs * K) >> 2; int index = lut_size_local + __popc(threads_in_bounds & mask); lookup_table[index] = lut_entry; } lut_size_local += __popc(threads_in_bounds); rs += 32; } } """ if bsum: bsum_code = r""" float local_bsum = result[q_offset].f[0] + result[q_offset].f[1] + result[q_offset].f[2] + result[q_offset].f[3]; atomicAdd(&bsum[filter_id], local_bsum); """ else: bsum_code = "" if operation == "update": a_name = "image" b_name = "error" else: if operation == "fprop": a_name = "image" b_name = "filter" elif operation == "bprop": a_name = "error" b_name = "filter" if filter_bounds_check: filter_load_cond = "int filter_load_in_bounds = (((filter_id + threadIdx.x) << 2) < K);" check_filter_cond = "(!filter_load_in_bounds) ? 
make_float4(0, 0, 0, 0) :" else: filter_load_cond = "" check_filter_cond = "" header_code = r""" #define TILE_DIM 32 #define ITEMS_PER_THREAD 4 #define THREADS_DIM 8 #define REG_TILE_X 4 #define REG_TILE_Y 4 #define THREADS_DIM_X 8 #define THREADS_DIM_Y 8 #define SM_TILE_X (REG_TILE_X * THREADS_DIM_X) #define SM_TILE_Y (REG_TILE_Y * THREADS_DIM_Y) #define NUM_ROWS 8 #define FILTER_SIZE %(filter_size)s #define MAGIC_FILTER_SIZE %(magic_filter_size)s #define SHIFT_FILTER_SIZE %(shift_filter_size)s typedef union Matrix { %(type)s4 f4; %(type)s f[4]; } Matrix; __device__ inline void _idiv_fast(int numerator, int denominator, float rcp, int& result, int& remainder) { result = (int)((float)numerator * rcp); remainder = numerator - (result * denominator); result = (remainder >= denominator) ? (result + 1) : result; remainder = (remainder >= denominator) ? (remainder - denominator) : remainder; } __device__ inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift, int denominator, int& result, int& remainder) { if(magic == 1) { result = numerator >> shift; } else { unsigned long long res64 = numerator * (unsigned long long)magic; result = ((int)(res64 >> 32) >> shift); } remainder = numerator - (result * denominator); } __device__ inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift, int denominator, int& result, int& remainder) { if(magic == 1) { result = numerator >> shift; } else { result = ((numerator * magic) >> shift); } remainder = numerator - (result * denominator); } //Note: N and K must be multiples of 4 //blockIdx.x is gemm tile id (K dimension) and output pixel id //blockIdx.y is gemm tile id (N dimension) //threadIdx.x is gemm tile offset (K dimension) //threadIdx.y is gemm tile offset (N dimension) __global__ void conv_%(operation)s( %(type)s alpha, %(type)s beta, Matrix *I, Matrix *F, Matrix *O, float* bsum, int C, int D, int H, int W, int N, int T, int R, int S, int K, int M, int P, int Q, int stride_w, int stride_h, int padding_w, int padding_h, int dilation_w, int dilation_h, int input_channel_size, int filter_channel_size, int output_filter_size, int output_pixels, int grid_p, int grid_q, unsigned int magic_pq, unsigned int shift_pq, unsigned int magic_q, unsigned int shift_q, unsigned int magic_s, unsigned int shift_s) """ code = r""" { __shared__ int2 lookup_table[FILTER_SIZE]; __shared__ int lut_size; __shared__ Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X]; __shared__ Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y]; int lut_size_local = 0; //TODO: Use square access pattern to image data to increase cache hits int output_pixel, image_id; _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, image_id, output_pixel); image_id = (image_id * blockDim.x); //Zig zag along x axis to increase cache hits int temp_x, temp_y; _idiv_magic(output_pixel, magic_q, shift_q, Q, temp_y, temp_x); int output_pixel_x = (temp_y & 1) ? 
(Q - temp_x - 1) : temp_x; int output_pixel_y = temp_y; output_pixel = output_pixel_x + (output_pixel_y * Q); int filter_id = blockIdx.y * blockDim.y; int tid = threadIdx.x + threadIdx.y * blockDim.x; //Offset buffers based on thread id I = &(I[image_id + threadIdx.x]); F = &(F[filter_id + threadIdx.x]); %(filter_load_cond)s //Compute lookup table for filter/image data %(lut_code)s if(tid == 0) { lut_size = lut_size_local; } __syncthreads(); lut_size_local = lut_size; Matrix result[REG_TILE_Y] = {0}; output_pixel = (output_pixel * N) >> 2; if(lut_size_local > 0) { //Evaluate gemm with outer product dimensions N, K and inner product CRS int CRS = lut_size_local * C; //Compute magic numbers for division by lut_size float reciprocal = 1.0f / (float)lut_size_local; //Initialize shared mem for first block int crs = CRS %% NUM_ROWS; crs = (crs == 0) ? 8 : crs; int c, rs; _idiv_fast(CRS - threadIdx.y - 1, lut_size_local, reciprocal, c, rs); int2 lut_entry = ((threadIdx.y & 7) >= crs) ? make_int2(0, 0) : lookup_table[rs]; %(a_name)s_data[threadIdx.y][threadIdx.x].f4 = ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) : I[(c * input_channel_size) + lut_entry.x].f4; %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) : F[(c * filter_channel_size) + lut_entry.y].f4; //Iterate over entire filter for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS) { __syncthreads(); #pragma unroll for(int i = 0; i < NUM_ROWS; i++) { Matrix load_row; Matrix load_col; load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4; load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++) { result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]); } } } __syncthreads(); //Load new image data and filter weights _idiv_fast(crs - threadIdx.y, lut_size_local, reciprocal, c, rs); lut_entry = lookup_table[rs]; %(a_name)s_data[threadIdx.y][threadIdx.x].f4 = I[(c * input_channel_size) + lut_entry.x].f4; %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4; } __syncthreads(); //Accumulate product for last iteration #pragma unroll for(int i = 0; i < NUM_ROWS; i++) { Matrix load_row; Matrix load_col; load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4; load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++) { result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]); } } } } //Store result filter_id = (filter_id + threadIdx.y) << 2; if(filter_id < K) { image_id += threadIdx.x; #pragma unroll for(int q_offset = 0; q_offset < 4; q_offset++) { if(filter_id < K) { int out_index = (filter_id * output_filter_size) + output_pixel + image_id; %(bsum_code)s Matrix cur_value = {0}; if(beta > 0.0f) { cur_value.f4 = O[out_index].f4; } result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta); result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta); result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta); result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta); O[out_index].f4 = result[q_offset].f4; } filter_id++; } } } """ update_code = r""" { __shared__ Matrix 
%(a_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4]; __shared__ Matrix %(b_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4]; //TODO: Use square access pattern to image data to increase cache hits int output_pixel, filter_id; _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, filter_id, output_pixel); filter_id = filter_id * TILE_DIM; int load_filter_id = filter_id + threadIdx.y; int filter_pixel_id = blockIdx.y * TILE_DIM; //TODO: Zig zag along x axis to increase cache hits int output_pixel_x, output_pixel_y; _idiv_magic(output_pixel, magic_q, shift_q, grid_q, output_pixel_y, output_pixel_x); //Compute input image and filter offsets for this pixel int c, rs; int crs = filter_pixel_id + threadIdx.y; _idiv_magic(crs, MAGIC_FILTER_SIZE, SHIFT_FILTER_SIZE, FILTER_SIZE, c, rs); int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x); int output_pixel_x_save = output_pixel_x; for(; output_pixel_y < P; output_pixel_y += grid_p) { for(output_pixel_x = output_pixel_x_save; output_pixel_x < Q; output_pixel_x += grid_q) { int base_x = output_pixel_x * stride_w - padding_w + filter_x * dilation_w; int base_y = output_pixel_y * stride_h - padding_h + filter_y * dilation_h; int crs_in_bounds = (c < C) && (base_x >= 0) && (base_x < W) && (base_y >= 0) && (base_y < H); int input_pixel = W * base_y + base_x; output_pixel = output_pixel_x + (Q * output_pixel_y); //Pre-multiply offset to simplify indexing input_pixel = (input_pixel * N) >> 2; output_pixel = (output_pixel * N) >> 2; //Evaluate gemm with outer product dimensions N, K and inner product CRS Matrix result[ITEMS_PER_THREAD] = {0}; //Load image data and transpose into shared mem //TODO: pad shared memory to avoid bank conflicts Matrix buffer; buffer.f4 = crs_in_bounds ? I[(c * input_channel_size) + input_pixel + threadIdx.x].f4 : make_float4(0, 0, 0, 0); %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; //Load error data and transpose into shared mem buffer.f4 = (load_filter_id < K) ? F[(load_filter_id * output_filter_size) + output_pixel + threadIdx.x].f4 : make_float4(0, 0, 0, 0); %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; //Iterate over entire minibatch for(int n = threadIdx.x + (TILE_DIM >> 2); n < (N >> 2); n += (TILE_DIM >> 2)) { __syncthreads(); #pragma unroll for(int i = 0; i < (TILE_DIM >> 2); i++) { Matrix row_image; Matrix row_error; row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4; row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++) { result[p_offset].f[q_offset] += (row_image.f[p_offset] * row_error.f[q_offset]); } } } __syncthreads(); //Load image data and transpose into shared mem buffer.f4 = crs_in_bounds ? 
I[(c * input_channel_size) + input_pixel + n].f4 : make_float4(0, 0, 0, 0); %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; //Load error data and transpose into shared mem buffer.f4 = (load_filter_id < K) ? F[(load_filter_id * output_filter_size) + output_pixel + n].f4 : make_float4(0, 0, 0, 0); %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; } __syncthreads(); //Accumulate product for last iteration #pragma unroll for(int i = 0; i < (TILE_DIM >> 2); i++) { Matrix row_image; Matrix row_error; row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4; row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++) { result[p_offset].f[q_offset] += (row_image.f[p_offset] * row_error.f[q_offset]); } } } //Reduce result between threads in warp Matrix outbound; int warp_y = threadIdx.y & 3; int warp_id = threadIdx.x + (threadIdx.y << 3); buffer.f4 = (warp_y == 0) ? result[0].f4 : (warp_y == 1) ? result[1].f4 : (warp_y == 2) ? result[2].f4 : result[3].f4; outbound.f4 = (warp_y == 0) ? result[3].f4 : (warp_y == 1) ? result[0].f4 : (warp_y == 2) ? result[1].f4 : result[2].f4; buffer.f[0] += __shfl(outbound.f[0], warp_id + 8); buffer.f[1] += __shfl(outbound.f[1], warp_id + 8); buffer.f[2] += __shfl(outbound.f[2], warp_id + 8); buffer.f[3] += __shfl(outbound.f[3], warp_id + 8); outbound.f4 = (warp_y == 0) ? result[2].f4 : (warp_y == 1) ? result[3].f4 : (warp_y == 2) ? result[0].f4 : result[1].f4; buffer.f[0] += __shfl(outbound.f[0], warp_id + 16); buffer.f[1] += __shfl(outbound.f[1], warp_id + 16); buffer.f[2] += __shfl(outbound.f[2], warp_id + 16); buffer.f[3] += __shfl(outbound.f[3], warp_id + 16); outbound.f4 = (warp_y == 0) ? result[1].f4 : (warp_y == 1) ? result[2].f4 : (warp_y == 2) ? 
result[3].f4 : result[0].f4; buffer.f[0] += __shfl(outbound.f[0], warp_id + 24); buffer.f[1] += __shfl(outbound.f[1], warp_id + 24); buffer.f[2] += __shfl(outbound.f[2], warp_id + 24); buffer.f[3] += __shfl(outbound.f[3], warp_id + 24); //Store result int idx_filter_id = filter_id + (threadIdx.x << 2); if(idx_filter_id < K && crs_in_bounds) { int out_index = (c * filter_channel_size) + (((rs * K) + (idx_filter_id)) >> 2); atomicAdd(&O[out_index].f[0], buffer.f[0]); atomicAdd(&O[out_index].f[1], buffer.f[1]); atomicAdd(&O[out_index].f[2], buffer.f[2]); atomicAdd(&O[out_index].f[3], buffer.f[3]); } } } } """ if operation == "update": code = header_code + update_code else: code = header_code + code magic = _magic64(filter_size) code = code % { "filter_size": filter_size, "magic_filter_size": magic[0], "shift_filter_size": magic[1], "type": _ew_types[dtype]["type"], "lut_code": lut_code, "bsum_code": bsum_code, "operation": operation, "a_name": a_name, "b_name": b_name, "filter_load_cond": filter_load_cond, "check_filter_cond": check_filter_cond } options = ["--use_fast_math"] if debug and operation == "bprop": options = options + ["-g", "-G"] module = SourceModule(code, options=options) kernel = module.get_function("conv_" + operation) kernel.prepare("ffPPPPIIIIIIIIIIIIIIIIIIIIIIIIIIIIII") kernel.name = "conv_" + operation return kernel
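# Illustrative sketch (not part of the original module): host-side version of the
# fprop lookup-table construction. For one output pixel it enumerates the R*S
# filter taps, keeps the taps whose input coordinates land in bounds, and records
# (image offset, filter offset) pairs. The kernel builds the same table with
# __ballot/__popc so a warp can append entries compactly in parallel, and it
# additionally scales the offsets by N and K and shifts right by 2 for float4
# loads; those details are omitted here.
def _example_fprop_lut(out_x, out_y, R, S, H, W, stride, padding, dilation=1):
    base_x = out_x * stride - padding
    base_y = out_y * stride - padding
    lut = []
    for r in range(R):
        for s in range(S):
            in_x = base_x + s * dilation
            in_y = base_y + r * dilation
            if 0 <= in_x < W and 0 <= in_y < H:
                lut.append((in_y * W + in_x, r * S + s))
    return lut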
def _get_bn_bprop_kernel(dtype, threads, compute_capability): if threads > 32: shr_code = "__shared__ float sPartials[THREADS * 2];" red_code = r""" sPartials[tid + THREADS*0] = grad_gamma; sPartials[tid + THREADS*1] = grad_beta; __syncthreads(); #pragma unroll for (int a = THREADS >> 1; a > 32; a >>= 1) { if ( tid < a ) { sPartials[tid + THREADS*0] += sPartials[tid + a + THREADS*0]; sPartials[tid + THREADS*1] += sPartials[tid + a + THREADS*1]; } __syncthreads(); } if ( tid < 32 ) { grad_gamma = sPartials[tid + THREADS*0] + sPartials[tid + 32 + THREADS*0]; grad_beta = sPartials[tid + THREADS*1] + sPartials[tid + 32 + THREADS*1]; #pragma unroll for (int i = 16; i > 0; i >>= 1) { grad_gamma += __shfl_xor(grad_gamma, i); grad_beta += __shfl_xor(grad_beta, i); } sPartials[tid + THREADS*0] = grad_gamma; sPartials[tid + THREADS*1] = grad_beta; } __syncthreads(); grad_gamma = sPartials[THREADS*0]; grad_beta = sPartials[THREADS*1]; """ else: shr_code = "" red_code = r""" #pragma unroll for (int i = 16; i > 0; i >>= 1) { grad_gamma += __shfl_xor(grad_gamma, i); grad_beta += __shfl_xor(grad_beta, i); } """ code = r""" #define THREADS %(threads)s %(common)s %(binary)s __global__ void batchnorm_bprop ( %(type)s* delta_out, float* grad_gamma_out, float* grad_beta_out, const %(type)s* delta_in, const %(type)s* x_in, const float* xsum_in, const float* xvar_in, const float* gamma_in, const float eps, const int N, bool binary) { %(share)s const int tid = threadIdx.x; const int bid = blockIdx.x; const float rcpN = 1.0f/(float)N; int offset = bid * N; const %(type)s* x_in0 = x_in + offset + tid; const %(type)s* d_in0 = delta_in + offset + tid; float xmean = __ldg(xsum_in + bid) * rcpN; float xvar = __ldg(xvar_in + bid); float gamma = __ldg(gamma_in + bid); float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps); float grad_gamma = 0.0f; float grad_beta = 0.0f; for (int i = tid; i < N; i += THREADS) { float x = %(cvt)s(__ldg(x_in0)); x_in0 += THREADS; float d = %(cvt)s(__ldg(d_in0)); d_in0 += THREADS; float xhat = 0.0f; if (binary) { xhat = shift_element(x - xmean, xvar_rcp_sqrt, true); } else { xhat = (x - xmean) * xvar_rcp_sqrt; } grad_gamma += xhat * d; grad_beta += d; } %(red)s if ( tid == 0 ) { *(grad_gamma_out + bid) = grad_gamma; *(grad_beta_out + bid) = grad_beta; } int start = N - (THREADS*4 - tid); offset += start; const %(type)s* x_in1 = x_in + offset; const %(type)s* d_in1 = delta_in + offset; delta_out += offset; for (int i = start; i >= -THREADS*3; i -= THREADS*4) { float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in1 + THREADS*0)) : 0.0f; float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in1 + THREADS*1)) : 0.0f; float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in1 + THREADS*2)) : 0.0f; float x3 = %(cvt)s(__ldg(x_in1 + THREADS*3)); float d0 = i >= -THREADS*0 ? %(cvt)s(__ldg(d_in1 + THREADS*0)) : 0.0f; float d1 = i >= -THREADS*1 ? %(cvt)s(__ldg(d_in1 + THREADS*1)) : 0.0f; float d2 = i >= -THREADS*2 ? 
%(cvt)s(__ldg(d_in1 + THREADS*2)) : 0.0f; float d3 = %(cvt)s(__ldg(d_in1 + THREADS*3)); x_in1 -= THREADS*4; d_in1 -= THREADS*4; float xhat0 = 0.0f; float xhat1 = 0.0f; float xhat2 = 0.0f; float xhat3 = 0.0f; float xtmp0 = 0.0f; float xtmp1 = 0.0f; float xtmp2 = 0.0f; float xtmp3 = 0.0f; float delta0 = 0.0f; float delta1 = 0.0f; float delta2 = 0.0f; float delta3 = 0.0f; if (binary) { xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true); xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true); xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true); xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true); xtmp0 = (shift_element(xhat0, grad_gamma, true) + grad_beta) * rcpN; xtmp1 = (shift_element(xhat1, grad_gamma, true) + grad_beta) * rcpN; xtmp2 = (shift_element(xhat2, grad_gamma, true) + grad_beta) * rcpN; xtmp3 = (shift_element(xhat3, grad_gamma, true) + grad_beta) * rcpN; delta0 = shift_element(shift_element(d0 - xtmp0, gamma, true), xvar_rcp_sqrt, true); delta1 = shift_element(shift_element(d1 - xtmp1, gamma, true), xvar_rcp_sqrt, true); delta2 = shift_element(shift_element(d2 - xtmp2, gamma, true), xvar_rcp_sqrt, true); delta3 = shift_element(shift_element(d3 - xtmp3, gamma, true), xvar_rcp_sqrt, true); } else { xhat0 = (x0 - xmean) * xvar_rcp_sqrt; xhat1 = (x1 - xmean) * xvar_rcp_sqrt; xhat2 = (x2 - xmean) * xvar_rcp_sqrt; xhat3 = (x3 - xmean) * xvar_rcp_sqrt; xtmp0 = (xhat0 * grad_gamma + grad_beta) * rcpN; xtmp1 = (xhat1 * grad_gamma + grad_beta) * rcpN; xtmp2 = (xhat2 * grad_gamma + grad_beta) * rcpN; xtmp3 = (xhat3 * grad_gamma + grad_beta) * rcpN; delta0 = gamma * (d0 - xtmp0) * xvar_rcp_sqrt; delta1 = gamma * (d1 - xtmp1) * xvar_rcp_sqrt; delta2 = gamma * (d2 - xtmp2) * xvar_rcp_sqrt; delta3 = gamma * (d3 - xtmp3) * xvar_rcp_sqrt; } %(delta0_out)s %(delta1_out)s %(delta2_out)s %(delta3_out)s if (i >= -THREADS*0) *(delta_out + THREADS*0) = delta0_val; if (i >= -THREADS*1) *(delta_out + THREADS*1) = delta1_val; if (i >= -THREADS*2) *(delta_out + THREADS*2) = delta2_val; *(delta_out + THREADS*3) = delta3_val; delta_out -= THREADS*4; } } """ out_code = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};") common_code = _common_round["nearest"].get(dtype, "") if dtype == "f2": common_code += _common_fp16_to_fp32 if (compute_capability[0] == 3 and compute_capability[1] < 5) or compute_capability[0] < 3: common_code += _common_kepler code = code % { "common" : common_code, "binary" : shift_element(), "share" : shr_code, "red" : red_code, "threads" : threads, "type" : _ew_types[dtype]["type"], "cvt" : _ew_types[dtype]["cvt"], "delta0_out" : out_code.format("delta0_val", "delta0"), "delta1_out" : out_code.format("delta1_val", "delta1"), "delta2_out" : out_code.format("delta2_val", "delta2"), "delta3_out" : out_code.format("delta3_val", "delta3"), } module = SourceModule(code, options=["--use_fast_math"]) kernel = module.get_function("batchnorm_bprop") kernel.prepare("PPPPPPPPfII") kernel.name = "batchnorm_bprop" return kernel
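# Illustrative sketch (not part of the original module): the per-row math of
# batchnorm_bprop in NumPy. delta_in is the incoming error for one feature row,
# x the matching inputs, and xsum/xvar the statistics saved during fprop.
import numpy as np

def _example_bn_bprop(delta_in, x, xsum, xvar, gamma, eps):
    N = x.shape[0]
    xmean = xsum / N
    rstd = 1.0 / np.sqrt(xvar + eps)
    xhat = (x - xmean) * rstd
    grad_gamma = np.sum(xhat * delta_in)
    grad_beta = np.sum(delta_in)
    xtmp = (xhat * grad_gamma + grad_beta) / N
    delta_out = gamma * (delta_in - xtmp) * rstd
    return delta_out, grad_gamma, grad_beta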
def _get_bn_fprop_kernel(dtype, threads, compute_capability): if threads > 32: shr_code = "__shared__ float sPartials[THREADS];" red_code = r""" sPartials[tid] = xvar; __syncthreads(); #pragma unroll for (int a = THREADS >> 1; a > 32; a >>= 1) { if ( tid < a ) sPartials[tid] += sPartials[tid + a]; __syncthreads(); } if ( tid < 32 ) { xvar = sPartials[tid] + sPartials[tid + 32]; #pragma unroll for (int i = 16; i > 0; i >>= 1) xvar += __shfl_xor(xvar, i); sPartials[tid] = xvar * rcpN; } __syncthreads(); xvar = sPartials[0]; """ else: shr_code = "" red_code = r""" #pragma unroll for (int i = 16; i > 0; i >>= 1) xvar += __shfl_xor(xvar, i); xvar *= rcpN; """ code = r""" #define THREADS %(threads)s %(common)s %(binary)s __global__ void batchnorm_fprop ( %(type)s* y_out, float* xvar_out, float* gmean_out, float* gvar_out, const %(type)s* x_in, const float* xsum_in, const float* gmean_in, const float* gvar_in, const float* gamma_in, const float* beta_in, const float eps, const float rho, const float accumbeta, const int N, const int relu, bool binary) { %(share)s const int tid = threadIdx.x; const int bid = blockIdx.x; int offset = bid * N; const %(type)s* x_in0 = x_in + offset + tid; const float rcpN = 1.0f/(float)N; float xmean = __ldg(xsum_in + bid) * rcpN; float xvar = 0.0f; for (int i = tid; i < N; i += THREADS) { float x = %(cvt)s(__ldg(x_in0)); x_in0 += THREADS; x -= xmean; if (binary) { xvar += shift_element(x, x, true); } else { xvar += x * x; } } %(red)s float gamma = __ldg(gamma_in + bid); float beta = __ldg(beta_in + bid); if ( tid == 0 ) { float gmean = __ldg(gmean_in + bid); float gvar = __ldg(gvar_in + bid); *(xvar_out + bid) = xvar; *(gmean_out + bid) = gmean * rho + (1.0f - rho) * xmean; *(gvar_out + bid) = gvar * rho + (1.0f - rho) * xvar; } float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps); int start = N - (THREADS*4 - tid); offset += start; x_in += offset; y_out += offset; for (int i = start; i >= -THREADS*3; i -= THREADS*4) { float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in + THREADS*0)) : 0.0f; float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in + THREADS*1)) : 0.0f; float x2 = i >= -THREADS*2 ? 
%(cvt)s(__ldg(x_in + THREADS*2)) : 0.0f; float x3 = %(cvt)s(__ldg(x_in + THREADS*3)); x_in -= THREADS*4; float xhat0 = 0.0f; float xhat1 = 0.0f; float xhat2 = 0.0f; float xhat3 = 0.0f; float y0 = 0.0f; float y1 = 0.0f; float y2 = 0.0f; float y3 = 0.0f; if (binary) { xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true); xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true); xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true); xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true); y0 = shift_element(xhat0, gamma, true) + beta; y1 = shift_element(xhat1, gamma, true) + beta; y2 = shift_element(xhat2, gamma, true) + beta; y3 = shift_element(xhat3, gamma, true) + beta; } else { xhat0 = (x0 - xmean) * xvar_rcp_sqrt; xhat1 = (x1 - xmean) * xvar_rcp_sqrt; xhat2 = (x2 - xmean) * xvar_rcp_sqrt; xhat3 = (x3 - xmean) * xvar_rcp_sqrt; y0 = xhat0 * gamma + beta; y1 = xhat1 * gamma + beta; y2 = xhat2 * gamma + beta; y3 = xhat3 * gamma + beta; } if (relu) { y0 = fmaxf(y0, 0.0f); y1 = fmaxf(y1, 0.0f); y2 = fmaxf(y2, 0.0f); y3 = fmaxf(y3, 0.0f); } %(y0_out)s %(y1_out)s %(y2_out)s %(y3_out)s if (accumbeta == 0.0) { if (i >= -THREADS*0) *(y_out + THREADS*0) = y0_val; if (i >= -THREADS*1) *(y_out + THREADS*1) = y1_val; if (i >= -THREADS*2) *(y_out + THREADS*2) = y2_val; *(y_out + THREADS*3) = y3_val; } else { if (i >= -THREADS*0) *(y_out + THREADS*0) = y_out[THREADS*0] * accumbeta + y0_val; if (i >= -THREADS*1) *(y_out + THREADS*1) = y_out[THREADS*1] * accumbeta + y1_val; if (i >= -THREADS*2) *(y_out + THREADS*2) = y_out[THREADS*2] * accumbeta + y2_val; *(y_out + THREADS*3) = y_out[THREADS*3] * accumbeta + y3_val; } y_out -= THREADS*4; } } """ out_code = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};") common_code = _common_round["nearest"].get(dtype, "") if dtype == "f2": common_code += _common_fp16_to_fp32 if (compute_capability[0] == 3 and compute_capability[1] < 5) or compute_capability[0] < 3: common_code += _common_kepler code = code % { "common": common_code, "binary": shift_element(), "share": shr_code, "red": red_code, "threads": threads, "type": _ew_types[dtype]["type"], "cvt": _ew_types[dtype]["cvt"], "y0_out": out_code.format("y0_val", "y0"), "y1_out": out_code.format("y1_val", "y1"), "y2_out": out_code.format("y2_val", "y2"), "y3_out": out_code.format("y3_val", "y3"), } module = SourceModule(code, options=["--use_fast_math"]) kernel = module.get_function("batchnorm_fprop") kernel.prepare("PPPPPPPPPPfffIII") kernel.name = "batchnorm_fprop" return kernel
def _get_bn_bprop_kernel(dtype, threads, compute_capability): if threads > 32: shr_code = "__shared__ float sPartials[THREADS * 2];" red_code = r""" sPartials[tid + THREADS*0] = grad_gamma; sPartials[tid + THREADS*1] = grad_beta; __syncthreads(); #pragma unroll for (int a = THREADS >> 1; a > 32; a >>= 1) { if ( tid < a ) { sPartials[tid + THREADS*0] += sPartials[tid + a + THREADS*0]; sPartials[tid + THREADS*1] += sPartials[tid + a + THREADS*1]; } __syncthreads(); } if ( tid < 32 ) { grad_gamma = sPartials[tid + THREADS*0] + sPartials[tid + 32 + THREADS*0]; grad_beta = sPartials[tid + THREADS*1] + sPartials[tid + 32 + THREADS*1]; #pragma unroll for (int i = 16; i > 0; i >>= 1) { grad_gamma += __shfl_xor(grad_gamma, i); grad_beta += __shfl_xor(grad_beta, i); } sPartials[tid + THREADS*0] = grad_gamma; sPartials[tid + THREADS*1] = grad_beta; } __syncthreads(); grad_gamma = sPartials[THREADS*0]; grad_beta = sPartials[THREADS*1]; """ else: shr_code = "" red_code = r""" #pragma unroll for (int i = 16; i > 0; i >>= 1) { grad_gamma += __shfl_xor(grad_gamma, i); grad_beta += __shfl_xor(grad_beta, i); } """ code = r""" #define THREADS %(threads)s %(common)s %(binary)s __global__ void batchnorm_bprop ( %(type)s* delta_out, float* grad_gamma_out, float* grad_beta_out, const %(type)s* delta_in, const %(type)s* x_in, const float* xsum_in, const float* xvar_in, const float* gamma_in, const float eps, const int N, bool binary) { %(share)s const int tid = threadIdx.x; const int bid = blockIdx.x; const float rcpN = 1.0f/(float)N; int offset = bid * N; const %(type)s* x_in0 = x_in + offset + tid; const %(type)s* d_in0 = delta_in + offset + tid; float xmean = __ldg(xsum_in + bid) * rcpN; float xvar = __ldg(xvar_in + bid); float gamma = __ldg(gamma_in + bid); float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps); float grad_gamma = 0.0f; float grad_beta = 0.0f; for (int i = tid; i < N; i += THREADS) { float x = %(cvt)s(__ldg(x_in0)); x_in0 += THREADS; float d = %(cvt)s(__ldg(d_in0)); d_in0 += THREADS; float xhat = 0.0f; if (binary) { xhat = shift_element(x - xmean, xvar_rcp_sqrt, true); } else { xhat = (x - xmean) * xvar_rcp_sqrt; } grad_gamma += xhat * d; grad_beta += d; } %(red)s if ( tid == 0 ) { *(grad_gamma_out + bid) = grad_gamma; *(grad_beta_out + bid) = grad_beta; } int start = N - (THREADS*4 - tid); offset += start; const %(type)s* x_in1 = x_in + offset; const %(type)s* d_in1 = delta_in + offset; delta_out += offset; for (int i = start; i >= -THREADS*3; i -= THREADS*4) { float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in1 + THREADS*0)) : 0.0f; float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in1 + THREADS*1)) : 0.0f; float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in1 + THREADS*2)) : 0.0f; float x3 = %(cvt)s(__ldg(x_in1 + THREADS*3)); float d0 = i >= -THREADS*0 ? %(cvt)s(__ldg(d_in1 + THREADS*0)) : 0.0f; float d1 = i >= -THREADS*1 ? %(cvt)s(__ldg(d_in1 + THREADS*1)) : 0.0f; float d2 = i >= -THREADS*2 ? 
%(cvt)s(__ldg(d_in1 + THREADS*2)) : 0.0f; float d3 = %(cvt)s(__ldg(d_in1 + THREADS*3)); x_in1 -= THREADS*4; d_in1 -= THREADS*4; float xhat0 = 0.0f; float xhat1 = 0.0f; float xhat2 = 0.0f; float xhat3 = 0.0f; float xtmp0 = 0.0f; float xtmp1 = 0.0f; float xtmp2 = 0.0f; float xtmp3 = 0.0f; float delta0 = 0.0f; float delta1 = 0.0f; float delta2 = 0.0f; float delta3 = 0.0f; if (binary) { xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true); xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true); xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true); xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true); xtmp0 = (shift_element(xhat0, grad_gamma, true) + grad_beta) * rcpN; xtmp1 = (shift_element(xhat1, grad_gamma, true) + grad_beta) * rcpN; xtmp2 = (shift_element(xhat2, grad_gamma, true) + grad_beta) * rcpN; xtmp3 = (shift_element(xhat3, grad_gamma, true) + grad_beta) * rcpN; delta0 = shift_element(shift_element(d0 - xtmp0, gamma, true), xvar_rcp_sqrt, true); delta1 = shift_element(shift_element(d1 - xtmp1, gamma, true), xvar_rcp_sqrt, true); delta2 = shift_element(shift_element(d2 - xtmp2, gamma, true), xvar_rcp_sqrt, true); delta3 = shift_element(shift_element(d3 - xtmp3, gamma, true), xvar_rcp_sqrt, true); } else { xhat0 = (x0 - xmean) * xvar_rcp_sqrt; xhat1 = (x1 - xmean) * xvar_rcp_sqrt; xhat2 = (x2 - xmean) * xvar_rcp_sqrt; xhat3 = (x3 - xmean) * xvar_rcp_sqrt; xtmp0 = (xhat0 * grad_gamma + grad_beta) * rcpN; xtmp1 = (xhat1 * grad_gamma + grad_beta) * rcpN; xtmp2 = (xhat2 * grad_gamma + grad_beta) * rcpN; xtmp3 = (xhat3 * grad_gamma + grad_beta) * rcpN; delta0 = gamma * (d0 - xtmp0) * xvar_rcp_sqrt; delta1 = gamma * (d1 - xtmp1) * xvar_rcp_sqrt; delta2 = gamma * (d2 - xtmp2) * xvar_rcp_sqrt; delta3 = gamma * (d3 - xtmp3) * xvar_rcp_sqrt; } %(delta0_out)s %(delta1_out)s %(delta2_out)s %(delta3_out)s if (i >= -THREADS*0) *(delta_out + THREADS*0) = delta0_val; if (i >= -THREADS*1) *(delta_out + THREADS*1) = delta1_val; if (i >= -THREADS*2) *(delta_out + THREADS*2) = delta2_val; *(delta_out + THREADS*3) = delta3_val; delta_out -= THREADS*4; } } """ out_code = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};") common_code = _common_round["nearest"].get(dtype, "") if dtype == "f2": common_code += _common_fp16_to_fp32 if (compute_capability[0] == 3 and compute_capability[1] < 5) or compute_capability[0] < 3: common_code += _common_kepler code = code % { "common": common_code, "binary": shift_element(), "share": shr_code, "red": red_code, "threads": threads, "type": _ew_types[dtype]["type"], "cvt": _ew_types[dtype]["cvt"], "delta0_out": out_code.format("delta0_val", "delta0"), "delta1_out": out_code.format("delta1_val", "delta1"), "delta2_out": out_code.format("delta2_val", "delta2"), "delta3_out": out_code.format("delta3_val", "delta3"), } module = SourceModule(code, options=["--use_fast_math"]) kernel = module.get_function("batchnorm_bprop") kernel.prepare("PPPPPPPPfII") kernel.name = "batchnorm_bprop" return kernel
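# Hypothetical usage sketch (not in the original source): batchnorm_bprop
# reduces grad_gamma/grad_beta over one feature row per thread block
# (blockIdx.x indexes rows, offset = bid * N), so a natural launch is one
# block per row with `threads` threads, matching the THREADS value used when
# the kernel was built. The argument order follows the CUDA signature in the
# template and the "PPPPPPPPfII" prepare string; all pointers are assumed to
# be pre-allocated device allocations, and the helper name is illustrative.
def _example_bn_bprop_launch(kernel, delta_out, grad_gamma_out, grad_beta_out,
                             delta_in, x_in, xsum_in, xvar_in, gamma_in,
                             eps, N, nrows, threads=128, binary=False,
                             stream=None):
    kernel.prepared_async_call(
        (nrows, 1, 1), (threads, 1, 1), stream,
        delta_out, grad_gamma_out, grad_beta_out,
        delta_in, x_in, xsum_in, xvar_in, gamma_in,
        eps, N, int(binary))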
def _get_conv_kernel(dtype, filter_size, bsum, operation, filter_bounds_check=False, debug=False): """ Builds the convolution kernel for a specified filter size. Arguments: dtype (np.dtype): The data type which the kernel will operate on. filter_size (int): Total number of elements per filter (R * S) bsum (boolean): If set to true, kernel will include code to compute batch sum during fprop operation (string): Determines which kernel to build. options follow: 'fprop': Forward propagation of activations. 'bprop': Backward propagation of error. 'update': Computes gradients for filter weights based on error and inputs. filter_bounds_check (boolean): Checks if filter weight is in bounds when K is not a multiple of 32. debug (boolean): When set to true, kernels will be compiled with debug symbols. """ assert operation in ["fprop", "bprop", "update"] if operation == "fprop" or operation == "update": lut_code = r""" if(tid < 32) { int rs = tid; int base_x, base_y; base_x = output_pixel_x * stride_w - padding_w; base_y = output_pixel_y * stride_h - padding_h; unsigned int mask = (1 << tid) - 1; while(rs < FILTER_SIZE) { int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x); int index_x = base_x + filter_x; int index_y = base_y + filter_y; //Check if the index is valid int in_bounds = (index_x >= 0 && index_x < W && index_y >= 0 && index_y < H); unsigned int threads_in_bounds = __ballot(in_bounds); //Store lookup table entry if(in_bounds) { int2 lut_entry; lut_entry.x = ((index_y * W + index_x) * N) >> 2; lut_entry.y = (rs * K) >> 2; int index = lut_size_local + __popc(threads_in_bounds & mask); lookup_table[index] = lut_entry; } lut_size_local += __popc(threads_in_bounds); rs += 32; } } """ elif operation == "bprop": lut_code = r""" if(tid < 32) { int rs = tid; int base_q, base_p; base_q = output_pixel_x - (S - padding_w - 1); base_p = output_pixel_y - (R - padding_h - 1); unsigned int mask = (1 << tid) - 1; while(rs < FILTER_SIZE) { int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x); int index_q = base_q + filter_x; int index_p = base_p + filter_y; //Check if the index is valid int in_bounds = (((index_q % stride_w) | (index_p % stride_h)) == 0); index_q /= stride_w; index_p /= stride_h; in_bounds = in_bounds && (index_q >= 0 && index_q < W && index_p >= 0 && index_p < H); unsigned int threads_in_bounds = __ballot(in_bounds); //Store lookup table entry if(in_bounds) { int2 lut_entry; lut_entry.x = (((index_p * W) + index_q) * N) >> 2; lut_entry.y = (rs * K) >> 2; int index = lut_size_local + __popc(threads_in_bounds & mask); lookup_table[index] = lut_entry; } lut_size_local += __popc(threads_in_bounds); rs += 32; } } """ if bsum: bsum_code = r""" float local_bsum = result[q_offset].f[0] + result[q_offset].f[1] + result[q_offset].f[2] + result[q_offset].f[3]; atomicAdd(&bsum[filter_id], local_bsum); """ else: bsum_code = "" if operation == "update": a_name = "image" b_name = "error" else: if operation == "fprop": a_name = "image" b_name = "filter" elif operation == "bprop": a_name = "error" b_name = "filter" if filter_bounds_check: filter_load_cond = "int filter_load_in_bounds = (((filter_id + threadIdx.x) << 2) < K);" check_filter_cond = "(!filter_load_in_bounds) ? 
make_float4(0, 0, 0, 0) :" else: filter_load_cond = "" check_filter_cond = "" header_code = r""" #define TILE_DIM 32 #define ITEMS_PER_THREAD 4 #define THREADS_DIM 8 #define REG_TILE_X 4 #define REG_TILE_Y 4 #define THREADS_DIM_X 8 #define THREADS_DIM_Y 8 #define SM_TILE_X (REG_TILE_X * THREADS_DIM_X) #define SM_TILE_Y (REG_TILE_Y * THREADS_DIM_Y) #define NUM_ROWS 8 #define FILTER_SIZE %(filter_size)s #define MAGIC_FILTER_SIZE %(magic_filter_size)s #define SHIFT_FILTER_SIZE %(shift_filter_size)s typedef union Matrix { %(type)s4 f4; %(type)s f[4]; } Matrix; __device__ inline void _idiv_fast(int numerator, int denominator, float rcp, int& result, int& remainder) { result = (int)((float)numerator * rcp); remainder = numerator - (result * denominator); result = (remainder >= denominator) ? (result + 1) : result; remainder = (remainder >= denominator) ? (remainder - denominator) : remainder; } __device__ inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift, int denominator, int& result, int& remainder) { if(magic == 1) { result = numerator >> shift; } else { unsigned long long res64 = numerator * (unsigned long long)magic; result = ((int)(res64 >> 32) >> shift); } remainder = numerator - (result * denominator); } __device__ inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift, int denominator, int& result, int& remainder) { if(magic == 1) { result = numerator >> shift; } else { result = ((numerator * magic) >> shift); } remainder = numerator - (result * denominator); } //Note: N and K must be multiples of 4 //blockIdx.x is gemm tile id (K dimension) and output pixel id //blockIdx.y is gemm tile id (N dimension) //threadIdx.x is gemm tile offset (K dimension) //threadIdx.y is gemm tile offset (N dimension) __global__ void conv_%(operation)s( %(type)s alpha, %(type)s beta, Matrix *I, Matrix *F, Matrix *O, float* bsum, int C, int D, int H, int W, int N, int T, int R, int S, int K, int M, int P, int Q, int stride_w, int stride_h, int padding_w, int padding_h, int input_channel_size, int filter_channel_size, int output_filter_size, int output_pixels, int grid_p, int grid_q, unsigned int magic_pq, unsigned int shift_pq, unsigned int magic_q, unsigned int shift_q, unsigned int magic_s, unsigned int shift_s) """ code = r""" { __shared__ int2 lookup_table[FILTER_SIZE]; __shared__ int lut_size; __shared__ Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X]; __shared__ Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y]; int lut_size_local = 0; //TODO: Use square access pattern to image data to increase cache hits int output_pixel, image_id; _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, image_id, output_pixel); image_id = (image_id * blockDim.x); //Zig zag along x axis to increase cache hits int temp_x, temp_y; _idiv_magic(output_pixel, magic_q, shift_q, Q, temp_y, temp_x); int output_pixel_x = (temp_y & 1) ? 
(Q - temp_x - 1) : temp_x; int output_pixel_y = temp_y; output_pixel = output_pixel_x + (output_pixel_y * Q); int filter_id = blockIdx.y * blockDim.y; int tid = threadIdx.x + threadIdx.y * blockDim.x; //Offset buffers based on thread id I = &(I[image_id + threadIdx.x]); F = &(F[filter_id + threadIdx.x]); %(filter_load_cond)s //Compute lookup table for filter/image data %(lut_code)s if(tid == 0) { lut_size = lut_size_local; } __syncthreads(); lut_size_local = lut_size; Matrix result[REG_TILE_Y] = {0}; output_pixel = (output_pixel * N) >> 2; if(lut_size_local > 0) { //Evaluate gemm with outer product dimensions N, K and inner product CRS int CRS = lut_size_local * C; //Compute magic numbers for division by lut_size float reciprocal = 1.0f / (float)lut_size_local; //Initialize shared mem for first block int crs = CRS %% NUM_ROWS; crs = (crs == 0) ? 8 : crs; int c, rs; _idiv_fast(CRS - threadIdx.y - 1, lut_size_local, reciprocal, c, rs); int2 lut_entry = ((threadIdx.y & 7) >= crs) ? make_int2(0, 0) : lookup_table[rs]; %(a_name)s_data[threadIdx.y][threadIdx.x].f4 = ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) : I[(c * input_channel_size) + lut_entry.x].f4; %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) : F[(c * filter_channel_size) + lut_entry.y].f4; //Iterate over entire filter for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS) { __syncthreads(); #pragma unroll for(int i = 0; i < NUM_ROWS; i++) { Matrix load_row; Matrix load_col; load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4; load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++) { result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]); } } } __syncthreads(); //Load new image data and filter weights _idiv_fast(crs - threadIdx.y, lut_size_local, reciprocal, c, rs); lut_entry = lookup_table[rs]; %(a_name)s_data[threadIdx.y][threadIdx.x].f4 = I[(c * input_channel_size) + lut_entry.x].f4; %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4; } __syncthreads(); //Accumulate product for last iteration #pragma unroll for(int i = 0; i < NUM_ROWS; i++) { Matrix load_row; Matrix load_col; load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4; load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++) { result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]); } } } } //Store result filter_id = (filter_id + threadIdx.y) << 2; if(filter_id < K) { image_id += threadIdx.x; #pragma unroll for(int q_offset = 0; q_offset < 4; q_offset++) { if(filter_id < K) { int out_index = (filter_id * output_filter_size) + output_pixel + image_id; %(bsum_code)s Matrix cur_value = {0}; if(beta > 0.0f) { cur_value.f4 = O[out_index].f4; } result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta); result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta); result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta); result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta); O[out_index].f4 = result[q_offset].f4; } filter_id++; } } } """ update_code = r""" { __shared__ Matrix 
%(a_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4]; __shared__ Matrix %(b_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4]; //TODO: Use square access pattern to image data to increase cache hits int output_pixel, filter_id; _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, filter_id, output_pixel); filter_id = filter_id * TILE_DIM; int load_filter_id = filter_id + threadIdx.y; int filter_pixel_id = blockIdx.y * TILE_DIM; //TODO: Zig zag along x axis to increase cache hits int output_pixel_x, output_pixel_y; _idiv_magic(output_pixel, magic_q, shift_q, grid_q, output_pixel_y, output_pixel_x); //Compute input image and filter offsets for this pixel int c, rs; int crs = filter_pixel_id + threadIdx.y; _idiv_magic(crs, MAGIC_FILTER_SIZE, SHIFT_FILTER_SIZE, FILTER_SIZE, c, rs); int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x); int output_pixel_x_save = output_pixel_x; for(; output_pixel_y < P; output_pixel_y += grid_p) { for(output_pixel_x = output_pixel_x_save; output_pixel_x < Q; output_pixel_x += grid_q) { int base_x = output_pixel_x * stride_w - padding_w + filter_x; int base_y = output_pixel_y * stride_h - padding_h + filter_y; int crs_in_bounds = (c < C) && (base_x >= 0) && (base_x < W) && (base_y >= 0) && (base_y < H); int input_pixel = W * base_y + base_x; output_pixel = output_pixel_x + (Q * output_pixel_y); //Pre-multiply offset to simplify indexing input_pixel = (input_pixel * N) >> 2; output_pixel = (output_pixel * N) >> 2; //Evaluate gemm with outer product dimensions N, K and inner product CRS Matrix result[ITEMS_PER_THREAD] = {0}; //Load image data and transpose into shared mem //TODO: pad shared memory to avoid bank conflicts Matrix buffer; buffer.f4 = crs_in_bounds ? I[(c * input_channel_size) + input_pixel + threadIdx.x].f4 : make_float4(0, 0, 0, 0); %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; //Load error data and transpose into shared mem buffer.f4 = (load_filter_id < K) ? F[(load_filter_id * output_filter_size) + output_pixel + threadIdx.x].f4 : make_float4(0, 0, 0, 0); %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; //Iterate over entire minibatch for(int n = threadIdx.x + (TILE_DIM >> 2); n < (N >> 2); n += (TILE_DIM >> 2)) { __syncthreads(); #pragma unroll for(int i = 0; i < (TILE_DIM >> 2); i++) { Matrix row_image; Matrix row_error; row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4; row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++) { result[p_offset].f[q_offset] += (row_image.f[p_offset] * row_error.f[q_offset]); } } } __syncthreads(); //Load image data and transpose into shared mem buffer.f4 = crs_in_bounds ? 
I[(c * input_channel_size) + input_pixel + n].f4 : make_float4(0, 0, 0, 0); %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; //Load error data and transpose into shared mem buffer.f4 = (load_filter_id < K) ? F[(load_filter_id * output_filter_size) + output_pixel + n].f4 : make_float4(0, 0, 0, 0); %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; } __syncthreads(); //Accumulate product for last iteration #pragma unroll for(int i = 0; i < (TILE_DIM >> 2); i++) { Matrix row_image; Matrix row_error; row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4; row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++) { result[p_offset].f[q_offset] += (row_image.f[p_offset] * row_error.f[q_offset]); } } } //Reduce result between threads in warp Matrix outbound; int warp_y = threadIdx.y & 3; int warp_id = threadIdx.x + (threadIdx.y << 3); buffer.f4 = (warp_y == 0) ? result[0].f4 : (warp_y == 1) ? result[1].f4 : (warp_y == 2) ? result[2].f4 : result[3].f4; outbound.f4 = (warp_y == 0) ? result[3].f4 : (warp_y == 1) ? result[0].f4 : (warp_y == 2) ? result[1].f4 : result[2].f4; buffer.f[0] += __shfl(outbound.f[0], warp_id + 8); buffer.f[1] += __shfl(outbound.f[1], warp_id + 8); buffer.f[2] += __shfl(outbound.f[2], warp_id + 8); buffer.f[3] += __shfl(outbound.f[3], warp_id + 8); outbound.f4 = (warp_y == 0) ? result[2].f4 : (warp_y == 1) ? result[3].f4 : (warp_y == 2) ? result[0].f4 : result[1].f4; buffer.f[0] += __shfl(outbound.f[0], warp_id + 16); buffer.f[1] += __shfl(outbound.f[1], warp_id + 16); buffer.f[2] += __shfl(outbound.f[2], warp_id + 16); buffer.f[3] += __shfl(outbound.f[3], warp_id + 16); outbound.f4 = (warp_y == 0) ? result[1].f4 : (warp_y == 1) ? result[2].f4 : (warp_y == 2) ? 
result[3].f4 : result[0].f4; buffer.f[0] += __shfl(outbound.f[0], warp_id + 24); buffer.f[1] += __shfl(outbound.f[1], warp_id + 24); buffer.f[2] += __shfl(outbound.f[2], warp_id + 24); buffer.f[3] += __shfl(outbound.f[3], warp_id + 24); //Store result int idx_filter_id = filter_id + (threadIdx.x << 2); if(idx_filter_id < K && crs_in_bounds) { int out_index = (c * filter_channel_size) + (((rs * K) + (idx_filter_id)) >> 2); atomicAdd(&O[out_index].f[0], buffer.f[0]); atomicAdd(&O[out_index].f[1], buffer.f[1]); atomicAdd(&O[out_index].f[2], buffer.f[2]); atomicAdd(&O[out_index].f[3], buffer.f[3]); } } } } """ if operation == "update": code = header_code + update_code else: code = header_code + code magic = _magic64(filter_size) code = code % { "filter_size": filter_size, "magic_filter_size": magic[0], "shift_filter_size": magic[1], "type": _ew_types[dtype]["type"], "lut_code": lut_code, "bsum_code": bsum_code, "operation": operation, "a_name": a_name, "b_name": b_name, "filter_load_cond": filter_load_cond, "check_filter_cond": check_filter_cond } options = ["--use_fast_math"] if debug and operation == "bprop": options = options + ["-g", "-G"] module = SourceModule(code, options=options) kernel = module.get_function("conv_" + operation) kernel.prepare("ffPPPPIIIIIIIIIIIIIIIIIIIIIIIIIIII") kernel.name = "conv_" + operation return kernel
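# Hypothetical usage sketch (not in the original source): the fprop, bprop and
# update convolution kernels are all built from the getter above. filter_size
# is the per-filter element count (R * S, per the docstring), and
# filter_bounds_check is only needed when K is not a multiple of 32. R, S and
# K below are illustrative parameters, not module globals.
def _example_build_conv_kernels(dtype, R, S, K):
    fprop = _get_conv_kernel(dtype, R * S, bsum=False, operation="fprop")
    bprop = _get_conv_kernel(dtype, R * S, bsum=False, operation="bprop")
    update = _get_conv_kernel(dtype, R * S, bsum=False, operation="update",
                              filter_bounds_check=(K % 32 != 0))
    return fprop, bprop, update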
def _get_lut_bprop_kernel(dtype, deterministic=False): """ Builds the bprop kernel for lookup table layers based on templated code. If the deterministic version is requested, an index buffer must be passed as an argument. This index buffer re-orders items in the input tensor so that word_ids are sorted. This is required since we need to be sure that each thread only updates weights for one word id. Arguments: dtype (np.dtype): The data which the kernel will operate on. deterministic (boolean): Builds the deterministic kernel when this is set to True. """ if not deterministic: code = r""" __global__ void lut_bprop( int* inputs, %(type)s* dW, %(type)s* errors, const int nin, const int embedding_dim, const int vocab_size, const int pad_idx) { const int tid = threadIdx.x; const int bid = blockIdx.x; int word_id = inputs[bid]; int error_row = bid * embedding_dim; int output_row = word_id * embedding_dim; if(word_id != pad_idx) { for(int i = tid; i < embedding_dim; i += blockDim.x) { atomicAdd(&dW[output_row + i], errors[error_row + i]); } } } """ code = code % { "type": _ew_types[dtype]["type"] } module = SourceModule(code, options=["--use_fast_math"]) kernel = module.get_function("lut_bprop") kernel.prepare("PPPIIIi") else: code = r""" __global__ void lut_bprop( int* inputs, int* index_buffer, %(type)s* dW, %(type)s* errors, const int nin, const int embedding_dim, const int vocab_size, const int pad_idx) { const int tid = threadIdx.x; const int bid = blockIdx.x; int index_position = bid; int index = index_buffer[index_position]; int word_id = inputs[index]; if((bid == 0 || word_id != inputs[index_buffer[bid - 1]]) && word_id != pad_idx) { int output_row = word_id * embedding_dim; do { int error_row = index * embedding_dim; for(int i = tid; i < embedding_dim; i += blockDim.x) { dW[output_row + i] += errors[error_row + i]; } index_position++; if(index_position == gridDim.x) { break; } index = index_buffer[index_position]; } while(inputs[index] == word_id); } } """ code = code % { "type": _ew_types[dtype]["type"] } module = SourceModule(code, options=["--use_fast_math"]) kernel = module.get_function("lut_bprop") kernel.prepare("PPPPIIIi") kernel.name = "lut_bprop" return kernel
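# Hypothetical usage sketch (not in the original source): in the
# non-deterministic variant above, blockIdx.x indexes the input word ids and
# atomicAdd resolves collisions between blocks that share a word id, so the
# launch below assumes one thread block per input token (taking `nin` as the
# number of inputs). Argument order follows the "PPPIIIi" prepare string;
# the helper name and thread count are illustrative only.
def _example_lut_bprop_launch(kernel, inputs, dW, errors, nin, embedding_dim,
                              vocab_size, pad_idx, threads=32, stream=None):
    kernel.prepared_async_call((nin, 1, 1), (threads, 1, 1), stream,
                               inputs, dW, errors,
                               nin, embedding_dim, vocab_size, pad_idx)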
def _get_nms_kernel(): code = r""" #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) int const threadsPerBlock = sizeof(unsigned int) * 8; __device__ inline float devIoU(float const * const a, float const * const b) { float left = max(a[0], b[0]), right = min(a[2], b[2]); float top = max(a[1], b[1]), bottom = min(a[3], b[3]); float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); float interS = width * height; float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); return interS / (Sa + Sb - interS); } __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, const float *dev_boxes, unsigned int *dev_mask) { const int row_start = blockIdx.y; const int col_start = blockIdx.x; // if (row_start > col_start) return; const int row_size = min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); const int col_size = min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); __shared__ float block_boxes[threadsPerBlock * 5]; if (threadIdx.x < col_size) { block_boxes[threadIdx.x * 5 + 0] = dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; block_boxes[threadIdx.x * 5 + 1] = dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; block_boxes[threadIdx.x * 5 + 2] = dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; block_boxes[threadIdx.x * 5 + 3] = dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; block_boxes[threadIdx.x * 5 + 4] = dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; } __syncthreads(); if (threadIdx.x < row_size) { const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; const float *cur_box = dev_boxes + cur_box_idx * 5; int i = 0; unsigned int t = 0; int start = 0; if (row_start == col_start) { start = threadIdx.x + 1; } for (i = start; i < col_size; i++) { if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { t |= 1UL << i; } } const int col_blocks = DIVUP(n_boxes, threadsPerBlock); dev_mask[cur_box_idx * col_blocks + col_start] = t; } } """ module = SourceModule(code) kernel = module.get_function("nms_kernel") sig = "1I 1f 2P" kernel.prepare(sig) return kernel
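# Hypothetical usage sketch (not in the original source): nms_kernel tiles the
# pairwise IoU matrix in 32x32 blocks (threadsPerBlock is 32 here because the
# masks are unsigned int), so the grid is square with DIVUP(n_boxes, 32)
# blocks per side and dev_mask must hold n_boxes * col_blocks 32-bit words.
# Argument order follows the "1I 1f 2P" prepare string; the helper name is
# illustrative and the device buffers are assumed to be pre-allocated.
def _example_nms_launch(kernel, n_boxes, overlap_thresh, dev_boxes, dev_mask,
                        stream=None):
    col_blocks = (n_boxes + 31) // 32
    kernel.prepared_async_call((col_blocks, col_blocks, 1), (32, 1, 1), stream,
                               n_boxes, overlap_thresh, dev_boxes, dev_mask)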