def __init__(self, lib, dtype,
             N, C, K,
             D, H, W,
             T, R, S,
             M, P, Q,
             pad_d, pad_h, pad_w,
             str_d, str_h, str_w,
             bsum):
    super(FpropCuda, self).__init__(lib, dtype)

    assert N % 32 == 0, "N dim must be multiple of 32"
    assert K % self.vec_size == 0, "K dim must be multiple of %d" % self.vec_size

    magic_PQ = magic64(P * Q)
    magic_Q = magic64(Q)
    magic_S = magic32(R * S + 32, S)

    HWN = H * W * N
    RST = R * S * T
    KRST = K * RST
    PQ = P * Q
    PQN = PQ * N

    self.kernel = _get_conv_kernel(dtype=self.dtype.str[1:], filter_size=R*S,
                                   bsum=bsum, operation="fprop")
    grid = (PQ * (-(-N // 32)), (-(-K // 32)), 1)
    block = (8, 8, 1)
    static_kernel_args = _flatten([C, D, H, W, N, T, R, S, K, M, P, Q,
                                   str_w, str_h, pad_w, pad_h,
                                   HWN // 4, KRST // 4, PQN // 4,
                                   PQ, 0, 0,
                                   magic_PQ, magic_Q, magic_S])
    self.launch_args = [grid, block] + [None] * 7 + static_kernel_args

    self.shared = RST * 4 * 2
    self.flags = (bsum and 4)
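# ---------------------------------------------------------------------------
# Illustration: the magic64/magic32 helpers used above precompute a
# (multiplier, shift) pair so the kernels can turn integer division/modulo
# (splitting a flat block index into pixel coordinates) into a multiply and a
# shift. The sketch below is the standard unsigned construction and is
# illustrative only: the library's own helpers are assumed to encode the pair
# slightly differently (the CUDA _idiv_magic shifts the 64-bit product right
# by 32 before applying the stored shift), and magic32 additionally relies on
# the numerator staying small enough for a 32-bit product.
def magic_u(d, nbits=32):
    """Return (magic, shift) such that n // d == (n * magic) >> shift
    for every 0 <= n < 2**nbits (d > 0). Illustrative sketch only."""
    if d & (d - 1) == 0:
        return 1, d.bit_length() - 1          # power of two: plain shift, magic == 1
    for shift in range(nbits, 2 * nbits + 1):
        magic = -(-2 ** shift // d)           # ceil(2**shift / d)
        err = magic * d - 2 ** shift          # 0 <= err < d
        if err * (2 ** nbits - 1) < 2 ** shift:
            return magic, shift
    raise ValueError("no magic number found for divisor %d" % d)

# Quick self-check of the construction:
for _d in (3, 5, 7, 24, 113):
    _magic, _shift = magic_u(_d)
    for _n in (0, 1, _d - 1, _d, 10 ** 6, 2 ** 32 - 1):
        assert (_n * _magic) >> _shift == _n // _d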
def __init__(self, lib, dtype,
             N, C, K,
             D, H, W,
             T, R, S,
             M, P, Q,
             pad_d, pad_h, pad_w,
             str_d, str_h, str_w,
             bsum):
    super(BpropCuda, self).__init__(lib, dtype)

    assert N % 32 == 0, "N dim must be multiple of 32"
    assert K % self.vec_size == 0, "K dim must be multiple of %d" % self.vec_size

    magic_HW = magic64(H * W)
    magic_W = magic64(W)
    magic_RS = magic32(R * S * T + 32, R * S)
    magic_S = magic32(R * S + 32, S)

    HW = H * W
    HWN = HW * N
    RST = R * S * T
    CRST = C * RST
    PQ = P * Q
    PQN = PQ * N

    self.bsum = bsum
    self.kernel = _get_conv_kernel(dtype=self.dtype.str[1:], filter_size=R*S,
                                   bsum=bsum, operation="bprop")
    grid = (HW * (-(-N // 32)), -(-C // 32), 1)
    block = (8, 8, 1)
    static_kernel_args = _flatten([K, M, P, Q, N, T, R, S, C, D, H, W,
                                   str_w, str_h, pad_w, pad_h,
                                   PQN // 4, CRST // 4, HWN // 4,
                                   HW, 0, 0,
                                   magic_HW, magic_W, magic_S])
    self.launch_args = [grid, block] + [None] * 7 + static_kernel_args

    self.shared = R * S * T * 4 * 2
    self.flags = (bsum and 4)

    # generate the kernel args for dim shuffling CTRSK => KTRSC
    shuffle_grid = (ceil_div(K, 32), ceil_div(C, 32), R * S * T)
    self.shuffle_size = C * T * R * S * K * dtype.itemsize
    self.shuffle_args = [shuffle_grid, (32, 8, 1), None, None, None]
    self.shuffle_args.extend(_flatten([
        R * S * T * K, R * S * K, S * K, K,
        R * S * T * C, R * S * C, S * C, C,
        R * S, T, R, S, magic_RS, magic_S]))

    lib.set_scratch_size(self.shuffle_size)
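# ---------------------------------------------------------------------------
# Illustration: the shuffle launch above rearranges the filter tensor from
# CTRSK to KTRSC in scratch memory so bprop can reuse the fprop-style kernel
# with the roles of C and K exchanged. A NumPy sketch of the equivalent
# transpose (shapes are made up for illustration):
import numpy as np

_C, _T, _R, _S, _K = 8, 1, 3, 3, 16
F_ctrsk = np.arange(_C * _T * _R * _S * _K, dtype=np.float32).reshape(_C, _T, _R, _S, _K)

# CTRSK -> KTRSC: swap the C and K axes, keep T, R, S in place.
F_ktrsc = np.ascontiguousarray(F_ctrsk.transpose(4, 1, 2, 3, 0))

assert F_ktrsc.shape == (_K, _T, _R, _S, _C)
assert F_ktrsc[5, 0, 1, 2, 3] == F_ctrsk[3, 0, 1, 2, 5]
# The scratch buffer sized above (C*T*R*S*K*itemsize) holds exactly this copy.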
def __init__(self, lib, dtype,
             N, C, K,
             D, H, W,
             T, R, S,
             M, P, Q,
             pad_d, pad_h, pad_w,
             str_d, str_h, str_w):
    super(UpdateCuda, self).__init__(lib, dtype)

    assert N % 32 == 0, "N dim must be multiple of 32"

    HWN = H * W * N
    RS = R * S
    RST = RS * T
    KRST = K * RST
    CRSTK = KRST * C
    PQ = P * Q
    PQN = PQ * N
    magic_S = magic32(R * S + 32, S)

    if lib.deterministic:
        grid_P = 1
        grid_Q = 1
        self.determ = CRSTK
    else:
        grid_P = P
        grid_Q = Q
        self.determ = 0

    pq_blocks = grid_P * grid_Q
    magic_PQ = magic64(pq_blocks)
    magic_Q = magic64(grid_Q)

    self.kernel = _get_conv_kernel(dtype=self.dtype.str[1:], filter_size=R*S,
                                   bsum=False, operation="update")
    grid = (pq_blocks * (-(-K // 32)), (-(-(C * RS) // 32)), 1)
    block = (8, 32, 1)
    static_kernel_args = _flatten([C, D, H, W, N, T, R, S, K, M, P, Q,
                                   str_w, str_h, pad_w, pad_h,
                                   HWN // 4, KRST // 4, PQN // 4,
                                   PQ, grid_P, grid_Q,
                                   magic_PQ, magic_Q, magic_S])
    self.launch_args = [grid, block] + [None] * 7 + static_kernel_args

    lib.set_scratch_size((self.determ or C * T * R * S * K) * 4)
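# ---------------------------------------------------------------------------
# Illustration: the update kernel accumulates the weight gradient
#   dW[c, r, s, k] = sum over n, p, q of
#       I[c, p*str_h - pad_h + r, q*str_w - pad_w + s, n] * E[k, p, q, n]
# A slow NumPy reference for the 2D case (T == 1), with CHWN activations,
# KPQN errors and a CRSK filter layout assumed from the argument names above:
import numpy as np

def conv_update_reference(I, E, R, S, pad_h, pad_w, str_h, str_w):
    C, H, W, N = I.shape
    K, P, Q, _ = E.shape
    dW = np.zeros((C, R, S, K), dtype=np.float32)
    for p in range(P):
        for q in range(Q):
            for r in range(R):
                for s in range(S):
                    y = p * str_h - pad_h + r
                    x = q * str_w - pad_w + s
                    if 0 <= y < H and 0 <= x < W:
                        # (C, N) @ (N, K) contraction over the minibatch
                        dW[:, r, s, :] += I[:, y, x, :].astype(np.float32) @ \
                                          E[:, p, q, :].astype(np.float32).T
    return dW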
def _get_conv_kernel(dtype, filter_size, bsum, operation, filter_bounds_check=False, debug=False): """ Builds the convolution kernel for a specified filter size. Arguments: dtype (np.dtype): The data type which the kernel will operate on. filter_size (int): Total number of elements per filter (R * S) bsum (boolean): If set to true, kernel will include code to compute batch sum during fprop operation (string): Determines which kernel to build. options follow: 'fprop': Forward propagation of activations. 'bprop': Backward propagation of error. 'update': Computes gradients for filter weights based on error and inputs. filter_bounds_check (boolean): Checks if filter weight is in bounds when K is not a multiple of 32. debug (boolean): When set to true, kernels will be compiled with debug symbols. """ print('_get_conv_kernel dtype', dtype, 'filter_size', filter_size, 'bsum', bsum, 'operation', operation, 'filter_bounds_check', filter_bounds_check, 'debug', debug) assert operation in ["fprop", "bprop", "update"] if operation == "fprop" or operation == "update": lut_code = r""" if(tid < 32) { int rs = tid; int base_x, base_y; base_x = output_pixel_x * stride_w - padding_w; base_y = output_pixel_y * stride_h - padding_h; unsigned int mask = (1 << tid) - 1; while(rs < FILTER_SIZE) { int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x); int index_x = base_x + filter_x; int index_y = base_y + filter_y; //Check if the index is valid int in_bounds = (index_x >= 0 && index_x < W && index_y >= 0 && index_y < H); unsigned int threads_in_bounds = __ballot(in_bounds); //Store lookup table entry if(in_bounds) { int2 lut_entry; lut_entry.x = ((index_y * W + index_x) * N) >> 2; lut_entry.y = (rs * K) >> 2; int index = lut_size_local + __popc(threads_in_bounds & mask); lookup_table[index] = lut_entry; } lut_size_local += __popc(threads_in_bounds); rs += 32; } } """ elif operation == "bprop": lut_code = r""" if(tid < 32) { int rs = tid; int base_q, base_p; base_q = output_pixel_x - (S - padding_w - 1); base_p = output_pixel_y - (R - padding_h - 1); unsigned int mask = (1 << tid) - 1; while(rs < FILTER_SIZE) { int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x); int index_q = base_q + filter_x; int index_p = base_p + filter_y; //Check if the index is valid int in_bounds = (((index_q % stride_w) | (index_p % stride_h)) == 0); index_q /= stride_w; index_p /= stride_h; in_bounds = in_bounds && (index_q >= 0 && index_q < W && index_p >= 0 && index_p < H); unsigned int threads_in_bounds = __ballot(in_bounds); //Store lookup table entry if(in_bounds) { int2 lut_entry; lut_entry.x = (((index_p * W) + index_q) * N) >> 2; lut_entry.y = (rs * K) >> 2; int index = lut_size_local + __popc(threads_in_bounds & mask); lookup_table[index] = lut_entry; } lut_size_local += __popc(threads_in_bounds); rs += 32; } } """ if bsum: bsum_code = r""" float local_bsum = result[q_offset].f[0] + result[q_offset].f[1] + result[q_offset].f[2] + result[q_offset].f[3]; atomicAdd(&bsum[filter_id], local_bsum); """ else: bsum_code = "" if operation == "update": a_name = "image" b_name = "error" else: if operation == "fprop": a_name = "image" b_name = "filter" elif operation == "bprop": a_name = "error" b_name = "filter" if filter_bounds_check: filter_load_cond = "int filter_load_in_bounds = (((filter_id + threadIdx.x) << 2) < K);" check_filter_cond = "(!filter_load_in_bounds) ? 
make_float4(0, 0, 0, 0) :" else: filter_load_cond = "" check_filter_cond = "" header_code = r""" #define TILE_DIM 32 #define ITEMS_PER_THREAD 4 #define THREADS_DIM 8 #define REG_TILE_X 4 #define REG_TILE_Y 4 #define THREADS_DIM_X 8 #define THREADS_DIM_Y 8 #define SM_TILE_X (REG_TILE_X * THREADS_DIM_X) #define SM_TILE_Y (REG_TILE_Y * THREADS_DIM_Y) #define NUM_ROWS 8 #define FILTER_SIZE %(filter_size)s #define MAGIC_FILTER_SIZE %(magic_filter_size)s #define SHIFT_FILTER_SIZE %(shift_filter_size)s typedef union Matrix { %(type)s4 f4; %(type)s f[4]; } Matrix; __device__ inline void _idiv_fast(int numerator, int denominator, float rcp, int& result, int& remainder) { result = (int)((float)numerator * rcp); remainder = numerator - (result * denominator); result = (remainder >= denominator) ? (result + 1) : result; remainder = (remainder >= denominator) ? (remainder - denominator) : remainder; } __device__ inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift, int denominator, int& result, int& remainder) { if(magic == 1) { result = numerator >> shift; } else { unsigned long long res64 = numerator * (unsigned long long)magic; result = ((int)(res64 >> 32) >> shift); } remainder = numerator - (result * denominator); } __device__ inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift, int denominator, int& result, int& remainder) { if(magic == 1) { result = numerator >> shift; } else { result = ((numerator * magic) >> shift); } remainder = numerator - (result * denominator); } //Note: N and K must be multiples of 4 //blockIdx.x is gemm tile id (K dimension) and output pixel id //blockIdx.y is gemm tile id (N dimension) //threadIdx.x is gemm tile offset (K dimension) //threadIdx.y is gemm tile offset (N dimension) __global__ void conv_%(operation)s( %(type)s alpha, %(type)s beta, Matrix *I, Matrix *F, Matrix *O, float* bsum, int C, int D, int H, int W, int N, int T, int R, int S, int K, int M, int P, int Q, int stride_w, int stride_h, int padding_w, int padding_h, int input_channel_size, int filter_channel_size, int output_filter_size, int output_pixels, int grid_p, int grid_q, unsigned int magic_pq, unsigned int shift_pq, unsigned int magic_q, unsigned int shift_q, unsigned int magic_s, unsigned int shift_s) """ code = r""" { __shared__ int2 lookup_table[FILTER_SIZE]; __shared__ int lut_size; __shared__ Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X]; __shared__ Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y]; int lut_size_local = 0; //TODO: Use square access pattern to image data to increase cache hits int output_pixel, image_id; _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, image_id, output_pixel); image_id = (image_id * blockDim.x); //Zig zag along x axis to increase cache hits int temp_x, temp_y; _idiv_magic(output_pixel, magic_q, shift_q, Q, temp_y, temp_x); int output_pixel_x = (temp_y & 1) ? 
(Q - temp_x - 1) : temp_x; int output_pixel_y = temp_y; output_pixel = output_pixel_x + (output_pixel_y * Q); int filter_id = blockIdx.y * blockDim.y; int tid = threadIdx.x + threadIdx.y * blockDim.x; //Offset buffers based on thread id I = &(I[image_id + threadIdx.x]); F = &(F[filter_id + threadIdx.x]); %(filter_load_cond)s //Compute lookup table for filter/image data %(lut_code)s if(tid == 0) { lut_size = lut_size_local; } __syncthreads(); lut_size_local = lut_size; Matrix result[REG_TILE_Y] = {0}; output_pixel = (output_pixel * N) >> 2; if(lut_size_local > 0) { //Evaluate gemm with outer product dimensions N, K and inner product CRS int CRS = lut_size_local * C; //Compute magic numbers for division by lut_size float reciprocal = 1.0f / (float)lut_size_local; //Initialize shared mem for first block int crs = CRS %% NUM_ROWS; crs = (crs == 0) ? 8 : crs; int c, rs; _idiv_fast(CRS - threadIdx.y - 1, lut_size_local, reciprocal, c, rs); int2 lut_entry = ((threadIdx.y & 7) >= crs) ? make_int2(0, 0) : lookup_table[rs]; %(a_name)s_data[threadIdx.y][threadIdx.x].f4 = ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) : I[(c * input_channel_size) + lut_entry.x].f4; %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) : F[(c * filter_channel_size) + lut_entry.y].f4; //Iterate over entire filter for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS) { __syncthreads(); #pragma unroll for(int i = 0; i < NUM_ROWS; i++) { Matrix load_row; Matrix load_col; load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4; load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++) { result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]); } } } __syncthreads(); //Load new image data and filter weights _idiv_fast(crs - threadIdx.y, lut_size_local, reciprocal, c, rs); lut_entry = lookup_table[rs]; %(a_name)s_data[threadIdx.y][threadIdx.x].f4 = I[(c * input_channel_size) + lut_entry.x].f4; %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4; } __syncthreads(); //Accumulate product for last iteration #pragma unroll for(int i = 0; i < NUM_ROWS; i++) { Matrix load_row; Matrix load_col; load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4; load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++) { result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]); } } } } //Store result filter_id = (filter_id + threadIdx.y) << 2; if(filter_id < K) { image_id += threadIdx.x; #pragma unroll for(int q_offset = 0; q_offset < 4; q_offset++) { if(filter_id < K) { int out_index = (filter_id * output_filter_size) + output_pixel + image_id; %(bsum_code)s Matrix cur_value = {0}; if(beta > 0.0f) { cur_value.f4 = O[out_index].f4; } result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta); result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta); result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta); result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta); O[out_index].f4 = result[q_offset].f4; } filter_id++; } } } """ update_code = r""" { __shared__ Matrix 
%(a_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4]; __shared__ Matrix %(b_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4]; //TODO: Use square access pattern to image data to increase cache hits int output_pixel, filter_id; _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, filter_id, output_pixel); filter_id = filter_id * TILE_DIM; int load_filter_id = filter_id + threadIdx.y; int filter_pixel_id = blockIdx.y * TILE_DIM; //TODO: Zig zag along x axis to increase cache hits int output_pixel_x, output_pixel_y; _idiv_magic(output_pixel, magic_q, shift_q, grid_q, output_pixel_y, output_pixel_x); //Compute input image and filter offsets for this pixel int c, rs; int crs = filter_pixel_id + threadIdx.y; _idiv_magic(crs, MAGIC_FILTER_SIZE, SHIFT_FILTER_SIZE, FILTER_SIZE, c, rs); int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x); int output_pixel_x_save = output_pixel_x; for(; output_pixel_y < P; output_pixel_y += grid_p) { for(output_pixel_x = output_pixel_x_save; output_pixel_x < Q; output_pixel_x += grid_q) { int base_x = output_pixel_x * stride_w - padding_w + filter_x; int base_y = output_pixel_y * stride_h - padding_h + filter_y; int crs_in_bounds = (c < C) && (base_x >= 0) && (base_x < W) && (base_y >= 0) && (base_y < H); int input_pixel = W * base_y + base_x; output_pixel = output_pixel_x + (Q * output_pixel_y); //Pre-multiply offset to simplify indexing input_pixel = (input_pixel * N) >> 2; output_pixel = (output_pixel * N) >> 2; //Evaluate gemm with outer product dimensions N, K and inner product CRS Matrix result[ITEMS_PER_THREAD] = {0}; //Load image data and transpose into shared mem //TODO: pad shared memory to avoid bank conflicts Matrix buffer; buffer.f4 = crs_in_bounds ? I[(c * input_channel_size) + input_pixel + threadIdx.x].f4 : make_float4(0, 0, 0, 0); %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; //Load error data and transpose into shared mem buffer.f4 = (load_filter_id < K) ? F[(load_filter_id * output_filter_size) + output_pixel + threadIdx.x].f4 : make_float4(0, 0, 0, 0); %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; //Iterate over entire minibatch for(int n = threadIdx.x + (TILE_DIM >> 2); n < (N >> 2); n += (TILE_DIM >> 2)) { __syncthreads(); #pragma unroll for(int i = 0; i < (TILE_DIM >> 2); i++) { Matrix row_image; Matrix row_error; row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4; row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++) { result[p_offset].f[q_offset] += (row_image.f[p_offset] * row_error.f[q_offset]); } } } __syncthreads(); //Load image data and transpose into shared mem buffer.f4 = crs_in_bounds ? 
I[(c * input_channel_size) + input_pixel + n].f4 : make_float4(0, 0, 0, 0); %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; //Load error data and transpose into shared mem buffer.f4 = (load_filter_id < K) ? F[(load_filter_id * output_filter_size) + output_pixel + n].f4 : make_float4(0, 0, 0, 0); %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0]; %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1]; %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2]; %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3]; } __syncthreads(); //Accumulate product for last iteration #pragma unroll for(int i = 0; i < (TILE_DIM >> 2); i++) { Matrix row_image; Matrix row_error; row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4; row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++) { result[p_offset].f[q_offset] += (row_image.f[p_offset] * row_error.f[q_offset]); } } } //Reduce result between threads in warp Matrix outbound; int warp_y = threadIdx.y & 3; int warp_id = threadIdx.x + (threadIdx.y << 3); buffer.f4 = (warp_y == 0) ? result[0].f4 : (warp_y == 1) ? result[1].f4 : (warp_y == 2) ? result[2].f4 : result[3].f4; outbound.f4 = (warp_y == 0) ? result[3].f4 : (warp_y == 1) ? result[0].f4 : (warp_y == 2) ? result[1].f4 : result[2].f4; buffer.f[0] += __shfl(outbound.f[0], warp_id + 8); buffer.f[1] += __shfl(outbound.f[1], warp_id + 8); buffer.f[2] += __shfl(outbound.f[2], warp_id + 8); buffer.f[3] += __shfl(outbound.f[3], warp_id + 8); outbound.f4 = (warp_y == 0) ? result[2].f4 : (warp_y == 1) ? result[3].f4 : (warp_y == 2) ? result[0].f4 : result[1].f4; buffer.f[0] += __shfl(outbound.f[0], warp_id + 16); buffer.f[1] += __shfl(outbound.f[1], warp_id + 16); buffer.f[2] += __shfl(outbound.f[2], warp_id + 16); buffer.f[3] += __shfl(outbound.f[3], warp_id + 16); outbound.f4 = (warp_y == 0) ? result[1].f4 : (warp_y == 1) ? result[2].f4 : (warp_y == 2) ? 
result[3].f4 : result[0].f4; buffer.f[0] += __shfl(outbound.f[0], warp_id + 24); buffer.f[1] += __shfl(outbound.f[1], warp_id + 24); buffer.f[2] += __shfl(outbound.f[2], warp_id + 24); buffer.f[3] += __shfl(outbound.f[3], warp_id + 24); //Store result int idx_filter_id = filter_id + (threadIdx.x << 2); if(idx_filter_id < K && crs_in_bounds) { int out_index = (c * filter_channel_size) + (((rs * K) + (idx_filter_id)) >> 2); atomicAdd(&O[out_index].f[0], buffer.f[0]); atomicAdd(&O[out_index].f[1], buffer.f[1]); atomicAdd(&O[out_index].f[2], buffer.f[2]); atomicAdd(&O[out_index].f[3], buffer.f[3]); } } } } """ if operation == "update": code = header_code + update_code else: code = header_code + code magic = magic64(filter_size) code = code % { "filter_size": filter_size, "magic_filter_size": magic[0], "shift_filter_size": magic[1], "type": _ew_types[dtype]["type"], "lut_code": lut_code, "bsum_code": bsum_code, "operation": operation, "a_name": a_name, "b_name": b_name, "filter_load_cond": filter_load_cond, "check_filter_cond": check_filter_cond } options = ["--use_fast_math"] if debug and operation == "bprop": options = options + ["-g", "-G"] module = SourceModule(code, options=options) kernel = module.get_function("conv_" + operation) kernel.prepare("ffPPPPIIIIIIIIIIIIIIIIIIIIIIIIIIII") kernel.name = "conv_" + operation return kernel
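# ---------------------------------------------------------------------------
# Illustration: the fprop/update lut_code above uses one warp to build a
# compacted lookup table of in-bounds filter taps. __ballot collects the
# in-bounds flags of all 32 lanes and __popc(threads_in_bounds & mask) gives
# each in-bounds lane a dense slot. The OpenCL port below currently stubs
# these out (see its TODO comments); this host-side Python sketch shows the
# behaviour those stubs need to reproduce. Names and the divmod standing in
# for _idiv_magic32 are for illustration only.
def build_lookup_table(R, S, H, W, N, K, base_x, base_y):
    lut = []                                   # compacted (image_offset, filter_offset) pairs
    for rs_base in range(0, R * S, 32):        # the warp strides over taps 32 at a time
        flags = []
        for tid in range(32):
            rs = rs_base + tid
            if rs >= R * S:
                flags.append(0)
                continue
            filter_y, filter_x = divmod(rs, S)             # what _idiv_magic32 computes
            x, y = base_x + filter_x, base_y + filter_y
            flags.append(int(0 <= x < W and 0 <= y < H))
        ballot = sum(f << t for t, f in enumerate(flags))  # __ballot(in_bounds)
        entries = [None] * bin(ballot).count("1")
        for tid in range(32):
            rs = rs_base + tid
            if rs < R * S and flags[tid]:
                mask = (1 << tid) - 1                      # lanes below this one
                slot = bin(ballot & mask).count("1")       # __popc(threads_in_bounds & mask)
                filter_y, filter_x = divmod(rs, S)
                y, x = base_y + filter_y, base_x + filter_x
                entries[slot] = (((y * W + x) * N) >> 2, (rs * K) >> 2)
        lut.extend(entries)
    return lut                                 # len(lut) corresponds to lut_size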
def _get_conv_kernel(ctx, options, dtype, filter_size, bsum, operation, filter_bounds_check=False, debug=False): """ Builds the convolution kernel for a specified filter size. Arguments: dtype (np.dtype): The data type which the kernel will operate on. filter_size (int): Total number of elements per filter (R * S) bsum (boolean): If set to true, kernel will include code to compute batch sum during fprop operation (string): Determines which kernel to build. options follow: 'fprop': Forward propagation of activations. 'bprop': Backward propagation of error. 'update': Computes gradients for filter weights based on error and inputs. filter_bounds_check (boolean): Checks if filter weight is in bounds when K is not a multiple of 32. debug (boolean): When set to true, kernels will be compiled with debug symbols. """ assert operation in ["fprop"] assert not bsum assert operation in ["fprop", "bprop", "update"] if operation == "fprop" or operation == "update": lut_code = r""" if(tid < 32) { int rs = tid; int base_x, base_y; base_x = output_pixel_x * stride_w - padding_w; base_y = output_pixel_y * stride_h - padding_h; // This will have 1s for this tid, and all the tids below it // eg: // 1 << 4 - 1 // => 0b1111 unsigned int mask = (1 << tid) - 1; while(rs < FILTER_SIZE) { int filter_x, filter_y; _idiv_magic32(rs, magic_s, shift_s, S, &filter_y, &filter_x); int index_x = base_x + filter_x; int index_y = base_y + filter_y; //Check if the index is valid int in_bounds = (index_x >= 0 && index_x < W && index_y >= 0 && index_y < H); // from cuda manual: // __ballot(predicate) : // Evaluate predicate for all active threads of the warp and return an integer whose // Nth bit is set if and only if predicat // TODO: unsigned int threads_in_bounds = __ballot(in_bounds); //Store lookup table entry if(in_bounds) { int2 lut_entry; lut_entry.x = ((index_y * W + index_x) * N) >> 2; lut_entry.y = (rs * K) >> 2; int index = 0; // TODO: int index = lut_size_local + __popc(threads_in_bounds & mask); lookup_table[index] = lut_entry; } // TODO: lut_size_local += __popc(threads_in_bounds); rs += 32; } } """ bsum_code = "" if operation == "fprop": a_name = "image" b_name = "filter" if filter_bounds_check: filter_load_cond = "int filter_load_in_bounds = (((filter_id + get_local_id(0)) << 2) < K);" check_filter_cond = "(!filter_load_in_bounds) ? make_float4(0, 0, 0, 0) :" else: filter_load_cond = "" check_filter_cond = "" header_code = r""" #define TILE_DIM 32 #define ITEMS_PER_THREAD 4 #define THREADS_DIM 8 #define REG_TILE_X 4 #define REG_TILE_Y 4 #define THREADS_DIM_X 8 #define THREADS_DIM_Y 8 #define SM_TILE_X (REG_TILE_X * THREADS_DIM_X) #define SM_TILE_Y (REG_TILE_Y * THREADS_DIM_Y) #define NUM_ROWS 8 #define FILTER_SIZE %(filter_size)s #define MAGIC_FILTER_SIZE %(magic_filter_size)s #define SHIFT_FILTER_SIZE %(shift_filter_size)s typedef union Matrix { %(type)s4 f4; %(type)s f[4]; } Matrix; static inline void _idiv_fast(int numerator, int denominator, float rcp, int* p_result, int* p_remainder) { *p_result = (int)((float)numerator * rcp); *p_remainder = numerator - (*p_result * denominator); *p_result = (*p_remainder >= denominator) ? (*p_result + 1) : *p_result; *p_remainder = (*p_remainder >= denominator) ? 
(*p_remainder - denominator) : *p_remainder; } static inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift, int denominator, int* p_result, int* p_remainder) { if(magic == 1) { *p_result = numerator >> shift; } else { unsigned long long res64 = numerator * (unsigned long long)magic; *p_result = ((int)(res64 >> 32) >> shift); } *p_remainder = numerator - (*p_result * denominator); } static inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift, int denominator, int* p_result, int* p_remainder) { if(magic == 1) { *p_result = numerator >> shift; } else { *p_result = ((numerator * magic) >> shift); } *p_remainder = numerator - (*p_result * denominator); } //Note: N and K must be multiples of 4 //get_group_id(0) is gemm tile id (K dimension) and output pixel id //get_group_id(1) is gemm tile id (N dimension) //get_local_id(0) is gemm tile offset (K dimension) //get_local_id(1) is gemm tile offset (N dimension) kernel void conv_%(operation)s( %(type)s alpha, %(type)s beta, global Matrix *I, global Matrix *F, global Matrix *O, global float* bsum, int C, int D, int H, int W, int N, int T, int R, int S, int K, int M, int P, int Q, int stride_w, int stride_h, int padding_w, int padding_h, int input_channel_size, int filter_channel_size, int output_filter_size, int output_pixels, int grid_p, int grid_q, unsigned int magic_pq, unsigned int shift_pq, unsigned int magic_q, unsigned int shift_q, unsigned int magic_s, unsigned int shift_s) """ code = r""" { local int2 lookup_table[FILTER_SIZE]; local int lut_size; local Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X]; local Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y]; int lut_size_local = 0; //TODO: Use square access pattern to image data to increase cache hits int output_pixel, image_id; _idiv_magic(get_group_id(0), magic_pq, shift_pq, output_pixels, &image_id, &output_pixel); image_id = (image_id * get_local_size(0)); //Zig zag along x axis to increase cache hits int temp_x, temp_y; _idiv_magic(output_pixel, magic_q, shift_q, Q, &temp_y, &temp_x); int output_pixel_x = (temp_y & 1) ? (Q - temp_x - 1) : temp_x; int output_pixel_y = temp_y; output_pixel = output_pixel_x + (output_pixel_y * Q); int filter_id = get_group_id(1) * get_local_size(1); // tid is the id within the workgroup, in a flat 1d space int tid = get_local_id(0) + get_local_id(1) * get_local_size(0); //Offset buffers based on thread id I = &(I[image_id + get_local_id(0)]); F = &(F[filter_id + get_local_id(0)]); %(filter_load_cond)s //Compute lookup table for filter/image data %(lut_code)s if(tid == 0) { lut_size = lut_size_local; } barrier(CLK_LOCAL_MEM_FENCE); lut_size_local = lut_size; Matrix result[REG_TILE_Y]; output_pixel = (output_pixel * N) >> 2; if(lut_size_local > 0) { //Evaluate gemm with outer product dimensions N, K and inner product CRS int CRS = lut_size_local * C; //Compute magic numbers for division by lut_size float reciprocal = 1.0f / (float)lut_size_local; //Initialize shared mem for first block int crs = CRS %% NUM_ROWS; crs = (crs == 0) ? 8 : crs; int c, rs; _idiv_fast(CRS - get_local_id(1) - 1, lut_size_local, reciprocal, &c, &rs); int2 lut_entry = ((get_local_id(1) & 7) >= crs) ? (int2)0 : lookup_table[rs]; %(a_name)s_data[get_local_id(1)][get_local_id(0)].f4 = ((get_local_id(1) & 7) >= crs) ? (float4)0.0f : I[(c * input_channel_size) + lut_entry.x].f4; %(b_name)s_data[get_local_id(1)][get_local_id(0)].f4 = %(check_filter_cond)s ((get_local_id(1) & 7) >= crs) ? 
(float4)0.0f : F[(c * filter_channel_size) + lut_entry.y].f4; //Iterate over entire filter for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS) { barrier(CLK_LOCAL_MEM_FENCE); #pragma unroll for(int i = 0; i < NUM_ROWS; i++) { Matrix load_row; Matrix load_col; load_row.f4 = %(a_name)s_data[i][get_local_id(0)].f4; load_col.f4 = %(b_name)s_data[i][get_local_id(1)].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++) { result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]); } } } barrier(CLK_LOCAL_MEM_FENCE); //Load new image data and filter weights _idiv_fast(crs - get_local_id(1), lut_size_local, reciprocal, &c, &rs); lut_entry = lookup_table[rs]; %(a_name)s_data[get_local_id(1)][get_local_id(0)].f4 = I[(c * input_channel_size) + lut_entry.x].f4; %(b_name)s_data[get_local_id(1)][get_local_id(0)].f4 = %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4; } barrier(CLK_LOCAL_MEM_FENCE); //Accumulate product for last iteration #pragma unroll for(int i = 0; i < NUM_ROWS; i++) { Matrix load_row; Matrix load_col; load_row.f4 = %(a_name)s_data[i][get_local_id(0)].f4; load_col.f4 = %(b_name)s_data[i][get_local_id(1)].f4; //Accumulate product #pragma unroll for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++) { #pragma unroll for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++) { result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]); } } } } //Store result filter_id = (filter_id + get_local_id(1)) << 2; if(filter_id < K) { image_id += get_local_id(0); #pragma unroll for(int q_offset = 0; q_offset < 4; q_offset++) { if(filter_id < K) { int out_index = (filter_id * output_filter_size) + output_pixel + image_id; %(bsum_code)s Matrix cur_value; cur_value.f4 = (float4)0.0f; if(beta > 0.0f) { cur_value.f4 = O[out_index].f4; } result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta); result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta); result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta); result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta); O[out_index].f4 = result[q_offset].f4; } filter_id++; } } } """ code = header_code + code magic = magic64(filter_size) code = code % { "filter_size": filter_size, "magic_filter_size": magic[0], "shift_filter_size": magic[1], "type": _ew_types[dtype]["type"], "lut_code": lut_code, "bsum_code": bsum_code, "operation": operation, "a_name": a_name, "b_name": b_name, "filter_load_cond": filter_load_cond, "check_filter_cond": check_filter_cond } with open('/tmp/out.cl', 'w') as f: f.write(code) # options = ["--use_fast_math"] # if debug and operation == "bprop": # options = options + ["-g", "-G"] module = cl.Program(ctx, code).build()
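# ---------------------------------------------------------------------------
# Hypothetical host-side usage of the program built above with pyopencl. All
# names here (queue, buffers, alpha/beta, grid, static_kernel_args) are
# placeholders; the real wrapper is expected to supply them from launch_args.
# The CUDA fprop version launches grid=(PQ*ceil(N/32), ceil(K/32), 1) with
# block=(8, 8, 1), which in OpenCL terms is local_size=(8, 8) and
# global_size = grid * local_size.
import numpy as np
import pyopencl as cl

queue = cl.CommandQueue(ctx)
kernel = module.conv_fprop                       # kernel name is "conv_" + operation

local_size = (8, 8)
global_size = (grid[0] * 8, grid[1] * 8)

args = [np.float32(alpha), np.float32(beta), I_buf, F_buf, O_buf, bsum_buf]
args += [np.int32(v) for v in static_kernel_args[:22]]    # C .. grid_q
args += [np.uint32(v) for v in static_kernel_args[22:]]   # magic/shift pairs
kernel(queue, global_size, local_size, *args)
queue.finish()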