Example #1
    def __init__(self, lib, dtype,
                 N, C, K,
                 D, H, W,
                 T, R, S,
                 M, P, Q,
                 pad_d, pad_h, pad_w,
                 str_d, str_h, str_w,
                 bsum):

        super(FpropCuda, self).__init__(lib, dtype)

        assert N % 32 == 0, "N dim must be multiple of 32"
        assert K % self.vec_size == 0, "K dim must be multiple of %d" % self.vec_size

        magic_PQ = magic64(P*Q)
        magic_Q = magic64(Q)
        magic_S = magic32(R*S+32, S)
        HWN = H * W * N
        RST = R * S * T
        KRST = K * RST
        PQ = P * Q
        PQN = PQ * N
        self.kernel = _get_conv_kernel(dtype=self.dtype.str[1:], filter_size=R*S,
                                       bsum=bsum, operation="fprop")
        grid = (PQ * (-(-N // 32)), (-(-K // 32)), 1)
        block = (8, 8, 1)
        static_kernel_args = _flatten([C, D, H, W, N, T, R, S, K, M, P, Q,
                                       str_w, str_h, pad_w, pad_h,
                                       HWN // 4, KRST // 4, PQN // 4,
                                       PQ, 0, 0,
                                       magic_PQ, magic_Q, magic_S])
        self.launch_args = [grid, block] + [None] * 7 + static_kernel_args

        self.shared = RST * 4 * 2
        self.flags = (bsum and 4)
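
The magic64/magic32 helpers referenced above precompute a multiplier/shift pair so the kernels can replace division by a runtime constant with a multiply and a shift (consumed by _idiv_magic/_idiv_magic32 in the kernel source below). A minimal sketch, assuming the standard Hacker's Delight construction; the names match the calls above, but treat this as illustrative rather than the library's exact code:

def magic32(nmax, d):
    # Find (magic, shift) with n // d == (n * magic) >> shift
    # for all 0 <= n <= nmax.
    nc = ((nmax + 1) // d) * d - 1
    nbits = nmax.bit_length()
    for p in range(0, 2 * nbits + 1):
        if 2 ** p > nc * (d - 1 - (2 ** p - 1) % d):
            m = (2 ** p + d - 1 - (2 ** p - 1) % d) // d
            return (m, p)
    raise ValueError("can't find magic number for division by %d" % d)

def magic64(d):
    # 64-bit variant: the kernel forms a 64-bit product and shifts the
    # high word, so the shift count is rebased by 32 (see _idiv_magic).
    nmax = 0xffffffff if d == 3 else 0x7fffffff
    magic, shift = magic32(nmax, d)
    if magic != 1:
        shift -= 32
    return (magic, shift)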
Example #2
    def __init__(self, lib, dtype,
                 N, C, K,
                 D, H, W,
                 T, R, S,
                 M, P, Q,
                 pad_d, pad_h, pad_w,
                 str_d, str_h, str_w,
                 bsum):

        super(BpropCuda, self).__init__(lib, dtype)

        assert N % 32 == 0, "N dim must be multiple of 32"
        assert K % self.vec_size == 0, "K dim must be multiple of %d" % self.vec_size

        magic_HW = magic64(H*W)
        magic_W = magic64(W)
        magic_RS = magic32(R*S*T+32, R*S)
        magic_S = magic32(R*S+32, S)
        HW = H * W
        HWN = HW * N
        RST = R * S * T
        CRST = C * RST
        PQ = P * Q
        PQN = PQ * N

        self.bsum = bsum
        self.kernel = _get_conv_kernel(dtype=self.dtype.str[1:], filter_size=R*S,
                                       bsum=bsum, operation="bprop")
        grid = (HW * (-(-N // 32)), -(-C // 32), 1)
        block = (8, 8, 1)
        static_kernel_args = _flatten([K, M, P, Q, N, T, R, S, C, D, H, W,
                                       str_w, str_h, pad_w, pad_h,
                                       PQN // 4, CRST // 4, HWN // 4,
                                       HW, 0, 0,
                                       magic_HW, magic_W, magic_S])
        self.launch_args = [grid, block] + [None] * 7 + static_kernel_args

        self.shared = R*S*T * 4 * 2
        self.flags = (bsum and 4)

        # generate the kernel args for dim shuffling CTRSK => KTRSC
        shuffle_grid = (ceil_div(K, 32), ceil_div(C, 32), R*S*T)
        self.shuffle_size = C*T*R*S*K*dtype.itemsize
        self.shuffle_args = [shuffle_grid, (32, 8, 1), None, None, None]
        self.shuffle_args.extend(_flatten([
            R*S*T*K, R*S*K, S*K, K,
            R*S*T*C, R*S*C, S*C, C,
            R*S, T, R, S, magic_RS, magic_S]))

        lib.set_scratch_size(self.shuffle_size)
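
For reference, the CTRSK => KTRSC shuffle that shuffle_args drives is a 5-d transpose of the filter tensor, letting bprop treat output channels as input channels. A NumPy sketch with illustrative sizes (not part of the original):

import numpy as np

C, T, R, S, K = 4, 1, 3, 3, 8   # illustrative sizes
F = np.arange(C * T * R * S * K, dtype=np.float32).reshape(C, T, R, S, K)
F_ktrsc = np.ascontiguousarray(F.transpose(4, 1, 2, 3, 0))  # KTRSC layout
assert F_ktrsc.shape == (K, T, R, S, C)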
Example #3
    def __init__(self, lib, dtype,
                 N, C, K,
                 D, H, W,
                 T, R, S,
                 M, P, Q,
                 pad_d, pad_h, pad_w,
                 str_d, str_h, str_w):

        super(UpdateCuda, self).__init__(lib, dtype)

        assert N % 32 == 0, "N dim must be multiple of 32"

        HWN = H * W * N
        RS = R * S
        RST = RS * T
        KRST = K * RST
        CRSTK = KRST * C
        PQ = P * Q
        PQN = PQ * N
        magic_S = magic32(R*S+32, S)

        if lib.deterministic:
            grid_P = 1
            grid_Q = 1
            self.determ = CRSTK
        else:
            grid_P = P
            grid_Q = Q
            self.determ = 0

        pq_blocks = grid_P * grid_Q
        magic_PQ = magic64(pq_blocks)
        magic_Q = magic64(grid_Q)

        self.kernel = _get_conv_kernel(dtype=self.dtype.str[1:], filter_size=R*S,
                                       bsum=False, operation="update")
        grid = (pq_blocks * (-(-K // 32)), (-(-(C*RS) // 32)), 1)
        block = (8, 32, 1)
        static_kernel_args = _flatten([C, D, H, W, N, T, R, S, K, M, P, Q,
                                       str_w, str_h, pad_w, pad_h,
                                       HWN // 4, KRST // 4, PQN // 4,
                                       PQ, grid_P, grid_Q,
                                       magic_PQ, magic_Q, magic_S])
        self.launch_args = [grid, block] + [None] * 7 + static_kernel_args

        lib.set_scratch_size((self.determ or C*T*R*S*K)*4)
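
All three constructors build their grids with the double-negation idiom -(-n // d), which is ceiling division for positive integers (the same thing ceil_div does in the shuffle setup above). Note that collapsing grid_P and grid_Q to 1 in deterministic mode leaves a single block accumulating each weight-gradient tile, so the order of the update kernel's atomicAdd accumulations no longer varies between runs. A quick check of the idiom:

def ceil_div(x, y):
    return -(-x // y)   # == math.ceil(x / y) for positive ints

assert ceil_div(65, 32) == 3 and -(-65 // 32) == 3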
Example #4
def _get_conv_kernel(dtype, filter_size, bsum, operation, filter_bounds_check=False, debug=False):
    """
    Builds the convolution kernel for a specified filter size.

    Arguments:
        dtype (np.dtype): The data type which the kernel will operate on.
        filter_size (int): Total number of elements per filter (R * S).
        bsum (boolean): If set to true, the kernel will include code to
            compute a batch sum during fprop.
        operation (string): Determines which kernel to build. Options:
            'fprop': Forward propagation of activations.
            'bprop': Backward propagation of error.
            'update': Computes gradients for filter weights based on error and inputs.
        filter_bounds_check (boolean): Checks whether the filter weight is in
            bounds when K is not a multiple of 32.
        debug (boolean): When set to true, kernels will be compiled with debug symbols.
    """
    print('_get_conv_kernel dtype', dtype, 'filter_size', filter_size,
          'bsum', bsum, 'operation', operation,
          'filter_bounds_check', filter_bounds_check, 'debug', debug)
    assert operation in ["fprop", "bprop", "update"]
    if operation == "fprop" or operation == "update":
        lut_code = r"""
    if(tid < 32)
    {
        int rs = tid;
        int base_x, base_y;

        base_x = output_pixel_x * stride_w - padding_w;
        base_y = output_pixel_y * stride_h - padding_h;

        unsigned int mask = (1 << tid) - 1;

        while(rs < FILTER_SIZE)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

            int index_x = base_x + filter_x;
            int index_y = base_y + filter_y;

            //Check if the index is valid
            int in_bounds = (index_x >= 0 && index_x < W && index_y >= 0 && index_y < H);
            unsigned int threads_in_bounds = __ballot(in_bounds);

            //Store lookup table entry
            if(in_bounds)
            {
                int2 lut_entry;
                lut_entry.x = ((index_y * W + index_x) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;

                int index = lut_size_local + __popc(threads_in_bounds & mask);
                lookup_table[index] = lut_entry;
            }

            lut_size_local += __popc(threads_in_bounds);

            rs += 32;
        }
    }
"""
    elif operation == "bprop":
        lut_code = r"""
    if(tid < 32)
    {
        int rs = tid;
        int base_q, base_p;

        base_q = output_pixel_x - (S - padding_w - 1);
        base_p = output_pixel_y - (R - padding_h - 1);

        unsigned int mask = (1 << tid) - 1;

        while(rs < FILTER_SIZE)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

            int index_q = base_q + filter_x;
            int index_p = base_p + filter_y;

            //Check if the index is valid
            int in_bounds = (((index_q % stride_w) | (index_p % stride_h)) == 0);
            index_q /= stride_w;
            index_p /= stride_h;
            in_bounds = in_bounds && (index_q >= 0 && index_q < W
                                      && index_p >= 0 && index_p < H);
            unsigned int threads_in_bounds = __ballot(in_bounds);

            //Store lookup table entry
            if(in_bounds)
            {
                int2 lut_entry;
                lut_entry.x = (((index_p * W) + index_q) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;

                int index = lut_size_local + __popc(threads_in_bounds & mask);
                lookup_table[index] = lut_entry;
            }

            lut_size_local += __popc(threads_in_bounds);

            rs += 32;
        }
    }
"""
    if bsum:
        bsum_code = r"""
            float local_bsum = result[q_offset].f[0] + result[q_offset].f[1] +
                               result[q_offset].f[2] + result[q_offset].f[3];
            atomicAdd(&bsum[filter_id], local_bsum);
"""
    else:
        bsum_code = ""

    if operation == "update":
        a_name = "image"
        b_name = "error"
    else:
        if operation == "fprop":
            a_name = "image"
            b_name = "filter"
        elif operation == "bprop":
            a_name = "error"
            b_name = "filter"

    if filter_bounds_check:
        filter_load_cond = "int filter_load_in_bounds = (((filter_id + threadIdx.x) << 2) < K);"
        check_filter_cond = "(!filter_load_in_bounds) ? make_float4(0, 0, 0, 0) :"
    else:
        filter_load_cond = ""
        check_filter_cond = ""

    header_code = r"""
#define TILE_DIM            32
#define ITEMS_PER_THREAD    4
#define THREADS_DIM         8

#define REG_TILE_X          4
#define REG_TILE_Y          4
#define THREADS_DIM_X       8
#define THREADS_DIM_Y       8
#define SM_TILE_X           (REG_TILE_X * THREADS_DIM_X)
#define SM_TILE_Y           (REG_TILE_Y * THREADS_DIM_Y)

#define NUM_ROWS            8
#define FILTER_SIZE         %(filter_size)s
#define MAGIC_FILTER_SIZE   %(magic_filter_size)s
#define SHIFT_FILTER_SIZE   %(shift_filter_size)s

typedef union Matrix {
    %(type)s4 f4;
    %(type)s f[4];
} Matrix;

__device__ inline void _idiv_fast(int numerator, int denominator, float rcp,
                                 int& result, int& remainder)
{
    result = (int)((float)numerator * rcp);
    remainder = numerator - (result * denominator);
    result = (remainder >= denominator) ? (result + 1) : result;
    remainder = (remainder >= denominator) ? (remainder - denominator) : remainder;
}

__device__ inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift,
                                   int denominator, int& result, int& remainder)
{
    if(magic == 1)
    {
        result = numerator >> shift;
    }
    else
    {
        unsigned long long res64 = numerator * (unsigned long long)magic;
        result = ((int)(res64 >> 32) >> shift);
    }
    remainder = numerator - (result * denominator);
}

__device__ inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift,
                                     int denominator, int& result, int& remainder)
{
    if(magic == 1)
    {
        result = numerator >> shift;
    }
    else
    {
        result = ((numerator * magic) >> shift);
    }
    remainder = numerator - (result * denominator);
}

//Note: N and K must be multiples of 4
//blockIdx.x is gemm tile id (K dimension) and output pixel id
//blockIdx.y is gemm tile id (N dimension)
//threadIdx.x is gemm tile offset (K dimension)
//threadIdx.y is gemm tile offset (N dimension)
__global__ void conv_%(operation)s(
                           %(type)s alpha, %(type)s beta,
                           Matrix *I, Matrix *F, Matrix *O, float* bsum,
                           int C, int D, int H, int W, int N,
                           int T, int R, int S, int K,
                           int M, int P, int Q,
                           int stride_w, int stride_h, int padding_w, int padding_h,
                           int input_channel_size, int filter_channel_size,
                           int output_filter_size,
                           int output_pixels, int grid_p, int grid_q,
                           unsigned int magic_pq, unsigned int shift_pq,
                           unsigned int magic_q, unsigned int shift_q,
                           unsigned int magic_s, unsigned int shift_s)

"""
    code = r"""
{
    __shared__ int2 lookup_table[FILTER_SIZE];
    __shared__ int lut_size;
    __shared__ Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X];
    __shared__ Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y];

    int lut_size_local = 0;

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, image_id;
    _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, image_id, output_pixel);
    image_id = (image_id * blockDim.x);

    //Zig zag along x axis to increase cache hits
    int temp_x, temp_y;
    _idiv_magic(output_pixel, magic_q, shift_q, Q, temp_y, temp_x);
    int output_pixel_x = (temp_y & 1) ? (Q - temp_x - 1) : temp_x;
    int output_pixel_y = temp_y;
    output_pixel = output_pixel_x + (output_pixel_y * Q);

    int filter_id = blockIdx.y * blockDim.y;
    int tid = threadIdx.x + threadIdx.y * blockDim.x;

    //Offset buffers based on thread id
    I = &(I[image_id  + threadIdx.x]);
    F = &(F[filter_id + threadIdx.x]);

    %(filter_load_cond)s

    //Compute lookup table for filter/image data
%(lut_code)s

    if(tid == 0)
    {
        lut_size = lut_size_local;
    }

    __syncthreads();

    lut_size_local = lut_size;
    Matrix result[REG_TILE_Y] = {0};
    output_pixel = (output_pixel * N) >> 2;
    if(lut_size_local > 0)
    {
        //Evaluate gemm with outer product dimensions N, K and inner product CRS
        int CRS = lut_size_local * C;

        //Compute magic numbers for division by lut_size
        float reciprocal = 1.0f / (float)lut_size_local;

        //Initialize shared mem for first block
        int crs = CRS %% NUM_ROWS;
        crs = (crs == 0) ? 8 : crs;

        int c, rs;
        _idiv_fast(CRS - threadIdx.y - 1, lut_size_local, reciprocal, c, rs);

        int2 lut_entry = ((threadIdx.y & 7) >= crs) ? make_int2(0, 0) : lookup_table[rs];
        %(a_name)s_data[threadIdx.y][threadIdx.x].f4 =
            ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) :
            I[(c * input_channel_size)  + lut_entry.x].f4;
        %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s
            ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) :
            F[(c * filter_channel_size) + lut_entry.y].f4;

        //Iterate over entire filter
        for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS)
        {
            __syncthreads();

            #pragma unroll
            for(int i = 0; i < NUM_ROWS; i++)
            {
                Matrix load_row;
                Matrix load_col;

                load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4;
                load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                    {
                        result[q_offset].f[p_offset] += (load_row.f[p_offset] *
                                                         load_col.f[q_offset]);
                    }
                }
            }

            __syncthreads();

            //Load new image data and filter weights
            _idiv_fast(crs - threadIdx.y, lut_size_local, reciprocal, c, rs);

            lut_entry = lookup_table[rs];
            %(a_name)s_data[threadIdx.y][threadIdx.x].f4 =
                I[(c * input_channel_size)  + lut_entry.x].f4;
            %(b_name)s_data[threadIdx.y][threadIdx.x].f4 =
                %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4;
        }

        __syncthreads();

        //Accumulate product for last iteration
        #pragma unroll
        for(int i = 0; i < NUM_ROWS; i++)
        {
            Matrix load_row;
            Matrix load_col;

            load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4;
            load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4;

            //Accumulate product
            #pragma unroll
            for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
            {
                #pragma unroll
                for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                {
                    result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]);
                }
            }
        }
    }

    //Store result
    filter_id = (filter_id + threadIdx.y) << 2;
    if(filter_id < K)
    {
        image_id += threadIdx.x;

        #pragma unroll
        for(int q_offset = 0; q_offset < 4; q_offset++)
        {
            if(filter_id < K)
            {
                int out_index = (filter_id * output_filter_size) + output_pixel + image_id;
                %(bsum_code)s

                Matrix cur_value = {0};
                if(beta > 0.0f)
                {
                    cur_value.f4 = O[out_index].f4;
                }

                result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta);
                result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta);
                result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta);
                result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta);

                O[out_index].f4 = result[q_offset].f4;
            }
            filter_id++;
        }
    }
}
"""

    update_code = r"""
{
    __shared__ Matrix %(a_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4];
    __shared__ Matrix %(b_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4];

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, filter_id;
    _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, filter_id, output_pixel);
    filter_id = filter_id * TILE_DIM;
    int load_filter_id = filter_id + threadIdx.y;

    int filter_pixel_id = blockIdx.y * TILE_DIM;

    //TODO: Zig zag along x axis to increase cache hits
    int output_pixel_x, output_pixel_y;
    _idiv_magic(output_pixel, magic_q, shift_q, grid_q, output_pixel_y, output_pixel_x);

    //Compute input image and filter offsets for this pixel
    int c, rs;
    int crs = filter_pixel_id + threadIdx.y;
    _idiv_magic(crs, MAGIC_FILTER_SIZE, SHIFT_FILTER_SIZE, FILTER_SIZE, c, rs);

    int filter_x, filter_y;
    _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

    int output_pixel_x_save = output_pixel_x;
    for(; output_pixel_y < P; output_pixel_y += grid_p)
    {
        for(output_pixel_x = output_pixel_x_save; output_pixel_x < Q; output_pixel_x += grid_q)
        {
            int base_x = output_pixel_x * stride_w - padding_w + filter_x;
            int base_y = output_pixel_y * stride_h - padding_h + filter_y;
            int crs_in_bounds = (c < C) && (base_x >= 0) && (base_x < W) &&
                                (base_y >= 0) && (base_y < H);
            int input_pixel = W * base_y + base_x;
            output_pixel = output_pixel_x + (Q * output_pixel_y);

            //Pre-multiply offset to simplify indexing
            input_pixel = (input_pixel * N) >> 2;
            output_pixel = (output_pixel * N) >> 2;

            //Evaluate gemm with outer product dimensions N, K and inner product CRS
            Matrix result[ITEMS_PER_THREAD] = {0};

            //Load image data and transpose into shared mem
            //TODO: pad shared memory to avoid bank conflicts
            Matrix buffer;
            buffer.f4 = crs_in_bounds ?
                        I[(c * input_channel_size) + input_pixel + threadIdx.x].f4 :
                        make_float4(0, 0, 0, 0);
            %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0];
            %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1];
            %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2];
            %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3];

            //Load error data and transpose into shared mem
            buffer.f4 = (load_filter_id < K) ?
                        F[(load_filter_id * output_filter_size) + output_pixel + threadIdx.x].f4 :
                        make_float4(0, 0, 0, 0);
            %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0];
            %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1];
            %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2];
            %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3];

            //Iterate over entire minibatch
            for(int n = threadIdx.x + (TILE_DIM >> 2); n < (N >> 2); n += (TILE_DIM >> 2))
            {
                __syncthreads();

                #pragma unroll
                for(int i = 0; i < (TILE_DIM >> 2); i++)
                {
                    Matrix row_image;
                    Matrix row_error;

                    row_image.f4 =
                        %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4;
                    row_error.f4 =
                        %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4;

                    //Accumulate product
                    #pragma unroll
                    for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++)
                    {
                        #pragma unroll
                        for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++)
                        {
                            result[p_offset].f[q_offset] +=
                                (row_image.f[p_offset] * row_error.f[q_offset]);
                        }
                    }
                }

                __syncthreads();

                //Load image data and transpose into shared mem
                buffer.f4 = crs_in_bounds ?
                    I[(c * input_channel_size) + input_pixel + n].f4 :
                    make_float4(0, 0, 0, 0);
                %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[0];
                %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[1];
                %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[2];
                %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[3];

                //Load error data and transpose into shared mem
                buffer.f4 = (load_filter_id < K) ?
                    F[(load_filter_id * output_filter_size) + output_pixel + n].f4 :
                    make_float4(0, 0, 0, 0);
                %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[0];
                %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[1];
                %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[2];
                %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[3];
            }

            __syncthreads();

            //Accumulate product for last iteration
            #pragma unroll
            for(int i = 0; i < (TILE_DIM >> 2); i++)
            {
                Matrix row_image;
                Matrix row_error;

                row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4;
                row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++)
                    {
                        result[p_offset].f[q_offset] +=
                            (row_image.f[p_offset] * row_error.f[q_offset]);
                    }
                }
            }

            //Reduce result between threads in warp
            Matrix outbound;
            int warp_y = threadIdx.y & 3;
            int warp_id = threadIdx.x + (threadIdx.y << 3);
            buffer.f4 = (warp_y == 0) ? result[0].f4 :
                        (warp_y == 1) ? result[1].f4 :
                        (warp_y == 2) ? result[2].f4 :
                        result[3].f4;

            outbound.f4 = (warp_y == 0) ? result[3].f4 :
                          (warp_y == 1) ? result[0].f4 :
                          (warp_y == 2) ? result[1].f4 :
                          result[2].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 8);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 8);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 8);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 8);

            outbound.f4 = (warp_y == 0) ? result[2].f4 :
                          (warp_y == 1) ? result[3].f4 :
                          (warp_y == 2) ? result[0].f4 :
                          result[1].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 16);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 16);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 16);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 16);

            outbound.f4 = (warp_y == 0) ? result[1].f4 :
                          (warp_y == 1) ? result[2].f4 :
                          (warp_y == 2) ? result[3].f4 :
                          result[0].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 24);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 24);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 24);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 24);

            //Store result
            int idx_filter_id = filter_id + (threadIdx.x << 2);
            if(idx_filter_id < K && crs_in_bounds)
            {
                int out_index = (c * filter_channel_size) + (((rs * K) + (idx_filter_id)) >> 2);

                atomicAdd(&O[out_index].f[0], buffer.f[0]);
                atomicAdd(&O[out_index].f[1], buffer.f[1]);
                atomicAdd(&O[out_index].f[2], buffer.f[2]);
                atomicAdd(&O[out_index].f[3], buffer.f[3]);
            }
        }
    }
}
"""
    if operation == "update":
        code = header_code + update_code
    else:
        code = header_code + code

    magic = magic64(filter_size)

    code = code % {
        "filter_size":          filter_size,
        "magic_filter_size":    magic[0],
        "shift_filter_size":    magic[1],
        "type":                 _ew_types[dtype]["type"],
        "lut_code":             lut_code,
        "bsum_code":            bsum_code,
        "operation":            operation,
        "a_name":               a_name,
        "b_name":               b_name,
        "filter_load_cond":     filter_load_cond,
        "check_filter_cond":    check_filter_cond
    }

    options = ["--use_fast_math"]
    if debug and operation == "bprop":
        options = options + ["-g", "-G"]
    module = SourceModule(code, options=options)

    kernel = module.get_function("conv_" + operation)
    kernel.prepare("ffPPPPIIIIIIIIIIIIIIIIIIIIIIIIIIII")
    kernel.name = "conv_" + operation
    return kernel
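
To tie this back to the launch_args built in Examples #1-#3: the seven None placeholders line up with the per-call arguments that precede the static integer list, namely the stream plus the "ffPPPP" prefix of the prepare string (alpha, beta, and the I, F, O, bsum device pointers); the run of "I" entries covers the kernel's 28 integer arguments. A hedged sketch of the eventual dispatch, assuming PyCUDA's prepared-call interface (the launch helper and its variable names are illustrative, not from the original):

def launch(layer, kernel, stream, alpha, beta, I_gpu, F_gpu, O_gpu, bsum_gpu):
    # slot 0 is grid, slot 1 is block; slots 2-8 are the per-call arguments
    args = list(layer.launch_args)
    args[2:9] = [stream, alpha, beta, I_gpu, F_gpu, O_gpu, bsum_gpu]
    kernel.prepared_async_call(*args, shared_size=layer.shared)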
Example #5
def _get_conv_kernel(ctx, options, dtype, filter_size, bsum, operation, filter_bounds_check=False, debug=False):
    """
    Builds the convolution kernel for a specified filter size.

    Arguments:
        dtype (np.dtype): The data type which the kernel will operate on.
        filter_size (int): Total number of elements per filter (R * S).
        bsum (boolean): If set to true, the kernel will include code to
            compute a batch sum during fprop.
        operation (string): Determines which kernel to build. Options:
            'fprop': Forward propagation of activations.
            'bprop': Backward propagation of error.
            'update': Computes gradients for filter weights based on error and inputs.
        filter_bounds_check (boolean): Checks whether the filter weight is in
            bounds when K is not a multiple of 32.
        debug (boolean): When set to true, kernels will be compiled with debug symbols.
    """
    assert operation in ["fprop"]
    assert not bsum
    assert operation in ["fprop", "bprop", "update"]
    if operation == "fprop" or operation == "update":
        lut_code = r"""
    // OpenCL 1.x has no warp ballot/popcount primitives (__ballot/__popc in
    // the CUDA version), so as a serial fallback work-item 0 builds the
    // whole lookup table by itself.
    if(tid == 0)
    {
        int base_x = output_pixel_x * stride_w - padding_w;
        int base_y = output_pixel_y * stride_h - padding_h;

        for(int rs = 0; rs < FILTER_SIZE; rs++)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, &filter_y, &filter_x);

            int index_x = base_x + filter_x;
            int index_y = base_y + filter_y;

            //Check if the index is valid
            if(index_x >= 0 && index_x < W && index_y >= 0 && index_y < H)
            {
                //Store lookup table entry
                int2 lut_entry;
                lut_entry.x = ((index_y * W + index_x) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;
                lookup_table[lut_size_local++] = lut_entry;
            }
        }
    }
"""
    bsum_code = ""

    if operation == "fprop":
        a_name = "image"
        b_name = "filter"

    if filter_bounds_check:
        filter_load_cond = "int filter_load_in_bounds = (((filter_id + get_local_id(0)) << 2) < K);"
        check_filter_cond = "(!filter_load_in_bounds) ? make_float4(0, 0, 0, 0) :"
    else:
        filter_load_cond = ""
        check_filter_cond = ""

    header_code = r"""
#define TILE_DIM            32
#define ITEMS_PER_THREAD    4
#define THREADS_DIM         8

#define REG_TILE_X          4
#define REG_TILE_Y          4
#define THREADS_DIM_X       8
#define THREADS_DIM_Y       8
#define SM_TILE_X           (REG_TILE_X * THREADS_DIM_X)
#define SM_TILE_Y           (REG_TILE_Y * THREADS_DIM_Y)

#define NUM_ROWS            8
#define FILTER_SIZE         %(filter_size)s
#define MAGIC_FILTER_SIZE   %(magic_filter_size)s
#define SHIFT_FILTER_SIZE   %(shift_filter_size)s

typedef union Matrix {
    %(type)s4 f4;
    %(type)s f[4];
} Matrix;

static inline void _idiv_fast(int numerator, int denominator, float rcp,
                                 int* p_result, int* p_remainder)
{
    *p_result = (int)((float)numerator * rcp);
    *p_remainder = numerator - (*p_result * denominator);
    *p_result = (*p_remainder >= denominator) ? (*p_result + 1) : *p_result;
    *p_remainder = (*p_remainder >= denominator) ? (*p_remainder - denominator) : *p_remainder;
}

static inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift,
                                   int denominator, int* p_result, int* p_remainder)
{
    if(magic == 1)
    {
        *p_result = numerator >> shift;
    }
    else
    {
        unsigned long long res64 = numerator * (unsigned long long)magic;
        *p_result = ((int)(res64 >> 32) >> shift);
    }
    *p_remainder = numerator - (*p_result * denominator);
}

static inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift,
                                     int denominator, int* p_result, int* p_remainder)
{
    if(magic == 1)
    {
        *p_result = numerator >> shift;
    }
    else
    {
        *p_result = ((numerator * magic) >> shift);
    }
    *p_remainder = numerator - (*p_result * denominator);
}

//Note: N and K must be multiples of 4
//get_group_id(0) is gemm tile id (K dimension) and output pixel id
//get_group_id(1) is gemm tile id (N dimension)
//get_local_id(0) is gemm tile offset (K dimension)
//get_local_id(1) is gemm tile offset (N dimension)
kernel void conv_%(operation)s(
                           %(type)s alpha, %(type)s beta,
                           global Matrix *I,
                           global Matrix *F,
                           global Matrix *O,
                           global float* bsum,
                           int C, int D, int H, int W, int N,
                           int T, int R, int S, int K,
                           int M, int P, int Q,
                           int stride_w, int stride_h, int padding_w, int padding_h,
                           int input_channel_size, int filter_channel_size,
                           int output_filter_size,
                           int output_pixels, int grid_p, int grid_q,
                           unsigned int magic_pq, unsigned int shift_pq,
                           unsigned int magic_q, unsigned int shift_q,
                           unsigned int magic_s, unsigned int shift_s)

"""
    code = r"""
{
    local int2 lookup_table[FILTER_SIZE];
    local int lut_size;
    local Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X];
    local Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y];

    int lut_size_local = 0;

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, image_id;
    _idiv_magic(get_group_id(0), magic_pq, shift_pq, output_pixels, &image_id, &output_pixel);
    image_id = (image_id * get_local_size(0));

    //Zig zag along x axis to increase cache hits
    int temp_x, temp_y;
    _idiv_magic(output_pixel, magic_q, shift_q, Q, &temp_y, &temp_x);
    int output_pixel_x = (temp_y & 1) ? (Q - temp_x - 1) : temp_x;
    int output_pixel_y = temp_y;
    output_pixel = output_pixel_x + (output_pixel_y * Q);

    int filter_id = get_group_id(1) * get_local_size(1);
    // tid is the id within the workgroup, in a flat 1d space
    int tid = get_local_id(0) + get_local_id(1) * get_local_size(0);

    //Offset buffers based on thread id
    I = &(I[image_id  + get_local_id(0)]);
    F = &(F[filter_id + get_local_id(0)]);

    %(filter_load_cond)s

    //Compute lookup table for filter/image data
%(lut_code)s

    if(tid == 0)
    {
        lut_size = lut_size_local;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    lut_size_local = lut_size;
    //Zero the accumulator registers (the CUDA version used "= {0}")
    Matrix result[REG_TILE_Y];
    for(int i = 0; i < REG_TILE_Y; i++)
    {
        result[i].f4 = (float4)0.0f;
    }
    output_pixel = (output_pixel * N) >> 2;
    if(lut_size_local > 0)
    {
        //Evaluate gemm with outer product dimensions N, K and inner product CRS
        int CRS = lut_size_local * C;

        //Compute magic numbers for division by lut_size
        float reciprocal = 1.0f / (float)lut_size_local;

        //Initialize shared mem for first block
        int crs = CRS %% NUM_ROWS;
        crs = (crs == 0) ? 8 : crs;

        int c, rs;
        _idiv_fast(CRS - get_local_id(1) - 1, lut_size_local, reciprocal, &c, &rs);

        int2 lut_entry = ((get_local_id(1) & 7) >= crs) ? (int2)0 : lookup_table[rs];
        %(a_name)s_data[get_local_id(1)][get_local_id(0)].f4 =
            ((get_local_id(1) & 7) >= crs) ? (float4)0.0f :
            I[(c * input_channel_size)  + lut_entry.x].f4;
        %(b_name)s_data[get_local_id(1)][get_local_id(0)].f4 = %(check_filter_cond)s
            ((get_local_id(1) & 7) >= crs) ? (float4)0.0f :
            F[(c * filter_channel_size) + lut_entry.y].f4;

        //Iterate over entire filter
        for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS)
        {
            barrier(CLK_LOCAL_MEM_FENCE);

            #pragma unroll
            for(int i = 0; i < NUM_ROWS; i++)
            {
                Matrix load_row;
                Matrix load_col;

                load_row.f4 = %(a_name)s_data[i][get_local_id(0)].f4;
                load_col.f4 = %(b_name)s_data[i][get_local_id(1)].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                    {
                        result[q_offset].f[p_offset] += (load_row.f[p_offset] *
                                                         load_col.f[q_offset]);
                    }
                }
            }

            barrier(CLK_LOCAL_MEM_FENCE);

            //Load new image data and filter weights
            _idiv_fast(crs - get_local_id(1), lut_size_local, reciprocal, &c, &rs);

            lut_entry = lookup_table[rs];
            %(a_name)s_data[get_local_id(1)][get_local_id(0)].f4 =
                I[(c * input_channel_size)  + lut_entry.x].f4;
            %(b_name)s_data[get_local_id(1)][get_local_id(0)].f4 =
                %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        //Accumulate product for last iteration
        #pragma unroll
        for(int i = 0; i < NUM_ROWS; i++)
        {
            Matrix load_row;
            Matrix load_col;

            load_row.f4 = %(a_name)s_data[i][get_local_id(0)].f4;
            load_col.f4 = %(b_name)s_data[i][get_local_id(1)].f4;

            //Accumulate product
            #pragma unroll
            for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
            {
                #pragma unroll
                for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                {
                    result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]);
                }
            }
        }
    }

    //Store result
    filter_id = (filter_id + get_local_id(1)) << 2;
    if(filter_id < K)
    {
        image_id += get_local_id(0);

        #pragma unroll
        for(int q_offset = 0; q_offset < 4; q_offset++)
        {
            if(filter_id < K)
            {
                int out_index = (filter_id * output_filter_size) + output_pixel + image_id;
                %(bsum_code)s

                Matrix cur_value;
                cur_value.f4 = (float4)0.0f;
                if(beta > 0.0f)
                {
                    cur_value.f4 = O[out_index].f4;
                }

                result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta);
                result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta);
                result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta);
                result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta);

                O[out_index].f4 = result[q_offset].f4;
            }
            filter_id++;
        }
    }
}
"""

    code = header_code + code

    magic = magic64(filter_size)

    code = code % {
        "filter_size":          filter_size,
        "magic_filter_size":    magic[0],
        "shift_filter_size":    magic[1],
        "type":                 _ew_types[dtype]["type"],
        "lut_code":             lut_code,
        "bsum_code":            bsum_code,
        "operation":            operation,
        "a_name":               a_name,
        "b_name":               b_name,
        "filter_load_cond":     filter_load_cond,
        "check_filter_cond":    check_filter_cond
    }
    with open('/tmp/out.cl', 'w') as f:
        f.write(code)

#    options = ["--use_fast_math"]
#    if debug and operation == "bprop":
#        options = options + ["-g", "-G"]
    module = cl.Program(ctx, code).build()

    # The CUDA variant returns a prepared kernel; do the analogous thing
    # here. pyopencl exposes each built kernel as an attribute of the Program.
    return getattr(module, "conv_" + operation)
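
A minimal sketch of driving this OpenCL builder, assuming a default device is available (the sizes are illustrative, and the helper names come from the surrounding module):

import pyopencl as cl

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
fprop = _get_conv_kernel(ctx, [], "f4", filter_size=9, bsum=False,
                         operation="fprop")
# fprop is then enqueued as fprop(queue, global_size, local_size, *args),
# with global/local sizes mirroring the CUDA grid * block layout and the
# argument order matching the conv_fprop signature above.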
Example #6
0
def _get_conv_kernel(ctx,
                     options,
                     dtype,
                     filter_size,
                     bsum,
                     operation,
                     filter_bounds_check=False,
                     debug=False):
    """
    Builds the convolution kernel for a specified filter size.

    Arguments:
        dtype (np.dtype): The data type which the kernel will operate on.
        filter_size (int): Total number of elements per filter (R * S)
        bsum (boolean): If set to true, kernel will include code to compute
            batch sum during fprop
        operation (string): Determines which kernel to build. options follow:
            'fprop': Forward propagation of activations.
            'bprop': Backward propagation of error.
            'update': Computes gradients for filter weights based on error and inputs.
        filter_bounds_check (boolean): Checks if filter weight is in bounds when K is
            not a multiple of 32.
        debug (boolean): When set to true, kernels will be compiled with debug symbols.
    """
    assert operation in ["fprop"]
    assert not bsum
    assert operation in ["fprop", "bprop", "update"]
    if operation == "fprop" or operation == "update":
        lut_code = r"""
    if(tid < 32)
    {
        int rs = tid;
        int base_x, base_y;

        base_x = output_pixel_x * stride_w - padding_w;
        base_y = output_pixel_y * stride_h - padding_h;

        // This will have 1s for this tid, and all the tids below it
        // eg:
        //    1 << 4 - 1
        // => 0b1111
        unsigned int mask = (1 << tid) - 1;

        while(rs < FILTER_SIZE)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, &filter_y, &filter_x);

            int index_x = base_x + filter_x;
            int index_y = base_y + filter_y;

            //Check if the index is valid
            int in_bounds = (index_x >= 0 && index_x < W && index_y >= 0 && index_y < H);
            // from cuda manual:
            // __ballot(predicate) :
            // Evaluate predicate for all active threads of the warp and return an integer whose
            // Nth bit is set if and only if predicat
// TODO:            unsigned int threads_in_bounds = __ballot(in_bounds);

            //Store lookup table entry
            if(in_bounds)
            {
                int2 lut_entry;
                lut_entry.x = ((index_y * W + index_x) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;

                int index = 0;
                // TODO:      int index = lut_size_local + __popc(threads_in_bounds & mask);
                lookup_table[index] = lut_entry;
            }

     // TODO:       lut_size_local += __popc(threads_in_bounds);

            rs += 32;
        }
    }
"""
    bsum_code = ""

    if operation == "fprop":
        a_name = "image"
        b_name = "filter"

    if filter_bounds_check:
        filter_load_cond = "int filter_load_in_bounds = (((filter_id + get_local_id(0)) << 2) < K);"
        check_filter_cond = "(!filter_load_in_bounds) ? make_float4(0, 0, 0, 0) :"
    else:
        filter_load_cond = ""
        check_filter_cond = ""

    header_code = r"""
#define TILE_DIM            32
#define ITEMS_PER_THREAD    4
#define THREADS_DIM         8

#define REG_TILE_X          4
#define REG_TILE_Y          4
#define THREADS_DIM_X       8
#define THREADS_DIM_Y       8
#define SM_TILE_X           (REG_TILE_X * THREADS_DIM_X)
#define SM_TILE_Y           (REG_TILE_Y * THREADS_DIM_Y)

#define NUM_ROWS            8
#define FILTER_SIZE         %(filter_size)s
#define MAGIC_FILTER_SIZE   %(magic_filter_size)s
#define SHIFT_FILTER_SIZE   %(shift_filter_size)s

typedef union Matrix {
    %(type)s4 f4;
    %(type)s f[4];
} Matrix;

static inline void _idiv_fast(int numerator, int denominator, float rcp,
                                 int* p_result, int* p_remainder)
{
    *p_result = (int)((float)numerator * rcp);
    *p_remainder = numerator - (*p_result * denominator);
    *p_result = (*p_remainder >= denominator) ? (*p_result + 1) : *p_result;
    *p_remainder = (*p_remainder >= denominator) ? (*p_remainder - denominator) : *p_remainder;
}

static inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift,
                                   int denominator, int* p_result, int* p_remainder)
{
    if(magic == 1)
    {
        *p_result = numerator >> shift;
    }
    else
    {
        unsigned long long res64 = numerator * (unsigned long long)magic;
        *p_result = ((int)(res64 >> 32) >> shift);
    }
    *p_remainder = numerator - (*p_result * denominator);
}

static inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift,
                                     int denominator, int* p_result, int* p_remainder)
{
    if(magic == 1)
    {
        *p_result = numerator >> shift;
    }
    else
    {
        *p_result = ((numerator * magic) >> shift);
    }
    *p_remainder = numerator - (*p_result * denominator);
}

//Note: N and K must be multiples of 4
//get_group_id(0) is gemm tile id (K dimension) and output pixel id
//get_group_id(1) is gemm tile id (N dimension)
//get_local_id(0) is gemm tile offset (K dimension)
//get_local_id(1) is gemm tile offset (N dimension)
kernel void conv_%(operation)s(
                           %(type)s alpha, %(type)s beta,
                           global Matrix *I,
                           global Matrix *F,
                           global Matrix *O,
                           global float* bsum,
                           int C, int D, int H, int W, int N,
                           int T, int R, int S, int K,
                           int M, int P, int Q,
                           int stride_w, int stride_h, int padding_w, int padding_h,
                           int input_channel_size, int filter_channel_size,
                           int output_filter_size,
                           int output_pixels, int grid_p, int grid_q,
                           unsigned int magic_pq, unsigned int shift_pq,
                           unsigned int magic_q, unsigned int shift_q,
                           unsigned int magic_s, unsigned int shift_s)

"""
    code = r"""
{
    local int2 lookup_table[FILTER_SIZE];
    local int lut_size;
    local Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X];
    local Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y];

    int lut_size_local = 0;

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, image_id;
    _idiv_magic(get_group_id(0), magic_pq, shift_pq, output_pixels, &image_id, &output_pixel);
    image_id = (image_id * get_local_size(0));

    //Zig zag along x axis to increase cache hits
    int temp_x, temp_y;
    _idiv_magic(output_pixel, magic_q, shift_q, Q, &temp_y, &temp_x);
    int output_pixel_x = (temp_y & 1) ? (Q - temp_x - 1) : temp_x;
    int output_pixel_y = temp_y;
    output_pixel = output_pixel_x + (output_pixel_y * Q);

    int filter_id = get_group_id(1) * get_local_size(1);
    // tid is the id within the workgroup, in a flat 1d space
    int tid = get_local_id(0) + get_local_id(1) * get_local_size(0);

    //Offset buffers based on thread id
    I = &(I[image_id  + get_local_id(0)]);
    F = &(F[filter_id + get_local_id(0)]);

    %(filter_load_cond)s

    //Compute lookup table for filter/image data
%(lut_code)s

    if(tid == 0)
    {
        lut_size = lut_size_local;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    lut_size_local = lut_size;
    Matrix result[REG_TILE_Y];
    output_pixel = (output_pixel * N) >> 2;
    if(lut_size_local > 0)
    {
        //Evaluate gemm with outer product dimensions N, K and inner product CRS
        int CRS = lut_size_local * C;

        //Compute magic numbers for division by lut_size
        float reciprocal = 1.0f / (float)lut_size_local;

        //Initialize shared mem for first block
        int crs = CRS %% NUM_ROWS;
        crs = (crs == 0) ? 8 : crs;

        int c, rs;
        _idiv_fast(CRS - get_local_id(1) - 1, lut_size_local, reciprocal, &c, &rs);

        int2 lut_entry = ((get_local_id(1) & 7) >= crs) ? (int2)0 : lookup_table[rs];
        %(a_name)s_data[get_local_id(1)][get_local_id(0)].f4 =
            ((get_local_id(1) & 7) >= crs) ? (float4)0.0f :
            I[(c * input_channel_size)  + lut_entry.x].f4;
        %(b_name)s_data[get_local_id(1)][get_local_id(0)].f4 = %(check_filter_cond)s
            ((get_local_id(1) & 7) >= crs) ? (float4)0.0f :
            F[(c * filter_channel_size) + lut_entry.y].f4;

        //Iterate over entire filter
        for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS)
        {
            barrier(CLK_LOCAL_MEM_FENCE);

            #pragma unroll
            for(int i = 0; i < NUM_ROWS; i++)
            {
                Matrix load_row;
                Matrix load_col;

                load_row.f4 = %(a_name)s_data[i][get_local_id(0)].f4;
                load_col.f4 = %(b_name)s_data[i][get_local_id(1)].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                    {
                        result[q_offset].f[p_offset] += (load_row.f[p_offset] *
                                                         load_col.f[q_offset]);
                    }
                }
            }

            barrier(CLK_LOCAL_MEM_FENCE);

            //Load new image data and filter weights
            _idiv_fast(crs - get_local_id(1), lut_size_local, reciprocal, &c, &rs);

            lut_entry = lookup_table[rs];
            %(a_name)s_data[get_local_id(1)][get_local_id(0)].f4 =
                I[(c * input_channel_size)  + lut_entry.x].f4;
            %(b_name)s_data[get_local_id(1)][get_local_id(0)].f4 =
                %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        //Accumulate product for last iteration
        #pragma unroll
        for(int i = 0; i < NUM_ROWS; i++)
        {
            Matrix load_row;
            Matrix load_col;

            load_row.f4 = %(a_name)s_data[i][get_local_id(0)].f4;
            load_col.f4 = %(b_name)s_data[i][get_local_id(1)].f4;

            //Accumulate product
            #pragma unroll
            for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
            {
                #pragma unroll
                for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                {
                    result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]);
                }
            }
        }
    }

    //Store result
    filter_id = (filter_id + get_local_id(1)) << 2;
    if(filter_id < K)
    {
        image_id += get_local_id(0);

        #pragma unroll
        for(int q_offset = 0; q_offset < 4; q_offset++)
        {
            if(filter_id < K)
            {
                int out_index = (filter_id * output_filter_size) + output_pixel + image_id;
                %(bsum_code)s

                Matrix cur_value;
                cur_value.f4 = (float4)0.0f;
                if(beta > 0.0f)
                {
                    cur_value.f4 = O[out_index].f4;
                }

                result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta);
                result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta);
                result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta);
                result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta);

                O[out_index].f4 = result[q_offset].f4;
            }
            filter_id++;
        }
    }
}
"""

    code = header_code + code

    magic = magic64(filter_size)

    code = code % {
        "filter_size": filter_size,
        "magic_filter_size": magic[0],
        "shift_filter_size": magic[1],
        "type": _ew_types[dtype]["type"],
        "lut_code": lut_code,
        "bsum_code": bsum_code,
        "operation": operation,
        "a_name": a_name,
        "b_name": b_name,
        "filter_load_cond": filter_load_cond,
        "check_filter_cond": check_filter_cond
    }
    with open('/tmp/out.cl', 'w') as f:
        f.write(code)

#    options = ["--use_fast_math"]
#    if debug and operation == "bprop":
#        options = options + ["-g", "-G"]
    module = cl.Program(ctx, code).build()
Example #7
0
def _get_conv_kernel(dtype,
                     filter_size,
                     bsum,
                     operation,
                     filter_bounds_check=False,
                     debug=False):
    """
    Builds the convolution kernel for a specified filter size.

    Arguments:
        dtype (np.dtype): The data type which the kernel will operate on.
        filter_size (int): Total number of elements per filter (R * S)
        bsum (boolean): If set to true, kernel will include code to compute
            batch sum during fprop
        operation (string): Determines which kernel to build. options follow:
            'fprop': Forward propagation of activations.
            'bprop': Backward propagation of error.
            'update': Computes gradients for filter weights based on error and inputs.
        filter_bounds_check (boolean): Checks if filter weight is in bounds when K is
            not a multiple of 32.
        debug (boolean): When set to true, kernels will be compiled with debug symbols.
    """
    print('_get_conv_kernel dtype', dtype, 'filter_size', filter_size, 'bsum',
          bsum, 'operation', operation, 'filter_bounds_check',
          filter_bounds_check, 'debug', debug)
    assert operation in ["fprop", "bprop", "update"]
    if operation == "fprop" or operation == "update":
        lut_code = r"""
    if(tid < 32)
    {
        int rs = tid;
        int base_x, base_y;

        base_x = output_pixel_x * stride_w - padding_w;
        base_y = output_pixel_y * stride_h - padding_h;

        unsigned int mask = (1 << tid) - 1;

        while(rs < FILTER_SIZE)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

            int index_x = base_x + filter_x;
            int index_y = base_y + filter_y;

            //Check if the index is valid
            int in_bounds = (index_x >= 0 && index_x < W && index_y >= 0 && index_y < H);
            unsigned int threads_in_bounds = __ballot(in_bounds);

            //Store lookup table entry
            if(in_bounds)
            {
                int2 lut_entry;
                lut_entry.x = ((index_y * W + index_x) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;

                int index = lut_size_local + __popc(threads_in_bounds & mask);
                lookup_table[index] = lut_entry;
            }

            lut_size_local += __popc(threads_in_bounds);

            rs += 32;
        }
    }
"""
    elif operation == "bprop":
        lut_code = r"""
    if(tid < 32)
    {
        int rs = tid;
        int base_q, base_p;

        base_q = output_pixel_x - (S - padding_w - 1);
        base_p = output_pixel_y - (R - padding_h - 1);

        unsigned int mask = (1 << tid) - 1;

        while(rs < FILTER_SIZE)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

            int index_q = base_q + filter_x;
            int index_p = base_p + filter_y;

            //Check if the index is valid
            int in_bounds = (((index_q % stride_w) | (index_p % stride_h)) == 0);
            index_q /= stride_w;
            index_p /= stride_h;
            in_bounds = in_bounds && (index_q >= 0 && index_q < W
                                      && index_p >= 0 && index_p < H);
            unsigned int threads_in_bounds = __ballot(in_bounds);

            //Store lookup table entry
            if(in_bounds)
            {
                int2 lut_entry;
                lut_entry.x = (((index_p * W) + index_q) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;

                int index = lut_size_local + __popc(threads_in_bounds & mask);
                lookup_table[index] = lut_entry;
            }

            lut_size_local += __popc(threads_in_bounds);

            rs += 32;
        }
    }
"""
    if bsum:
        bsum_code = r"""
            float local_bsum = result[q_offset].f[0] + result[q_offset].f[1] +
                               result[q_offset].f[2] + result[q_offset].f[3];
            atomicAdd(&bsum[filter_id], local_bsum);
"""
    else:
        bsum_code = ""

    if operation == "update":
        a_name = "image"
        b_name = "error"
    else:
        if operation == "fprop":
            a_name = "image"
            b_name = "filter"
        elif operation == "bprop":
            a_name = "error"
            b_name = "filter"

    if filter_bounds_check:
        filter_load_cond = "int filter_load_in_bounds = (((filter_id + threadIdx.x) << 2) < K);"
        check_filter_cond = "(!filter_load_in_bounds) ? make_float4(0, 0, 0, 0) :"
    else:
        filter_load_cond = ""
        check_filter_cond = ""

    header_code = r"""
#define TILE_DIM            32
#define ITEMS_PER_THREAD    4
#define THREADS_DIM         8

#define REG_TILE_X          4
#define REG_TILE_Y          4
#define THREADS_DIM_X       8
#define THREADS_DIM_Y       8
#define SM_TILE_X           (REG_TILE_X * THREADS_DIM_X)
#define SM_TILE_Y           (REG_TILE_Y * THREADS_DIM_Y)

#define NUM_ROWS            8
#define FILTER_SIZE         %(filter_size)s
#define MAGIC_FILTER_SIZE   %(magic_filter_size)s
#define SHIFT_FILTER_SIZE   %(shift_filter_size)s

typedef union Matrix {
    %(type)s4 f4;
    %(type)s f[4];
} Matrix;
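
//Matrix overlays a 4-wide vector with a scalar array so the kernels can
//issue vectorized loads/stores while still addressing single elements;
//this is also why N and K must be multiples of 4.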

__device__ inline void _idiv_fast(int numerator, int denominator, float rcp,
                                 int& result, int& remainder)
{
    result = (int)((float)numerator * rcp);
    remainder = numerator - (result * denominator);
    result = (remainder >= denominator) ? (result + 1) : result;
    remainder = (remainder >= denominator) ? (remainder - denominator) : remainder;
}

__device__ inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift,
                                   int denominator, int& result, int& remainder)
{
    if(magic == 1)
    {
        result = numerator >> shift;
    }
    else
    {
        unsigned long long res64 = numerator * (unsigned long long)magic;
        result = ((int)(res64 >> 32) >> shift);
    }
    remainder = numerator - (result * denominator);
}

__device__ inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift,
                                     int denominator, int& result, int& remainder)
{
    if(magic == 1)
    {
        result = numerator >> shift;
    }
    else
    {
        result = ((numerator * magic) >> shift);
    }
    remainder = numerator - (result * denominator);
}
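
//Worked example of the magic-number trick: for 32-bit unsigned n,
//n / 3 == (n * 0xAAAAAAAB) >> 33, i.e. _idiv_magic with magic = 0xAAAAAAAB
//and shift = 1 (the remaining >> 32 comes from taking the high word of the
//64-bit product). _idiv_magic32 keeps the product in 32 bits and is only
//safe for the small numerators (filter coordinates) it is used on here.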

//Note: N and K must be multiples of 4
//blockIdx.x is gemm tile id (K dimension) and output pixel id
//blockIdx.y is gemm tile id (N dimension)
//threadIdx.x is gemm tile offset (K dimension)
//threadIdx.y is gemm tile offset (N dimension)
__global__ void conv_%(operation)s(
                           %(type)s alpha, %(type)s beta,
                           Matrix *I, Matrix *F, Matrix *O, float* bsum,
                           int C, int D, int H, int W, int N,
                           int T, int R, int S, int K,
                           int M, int P, int Q,
                           int stride_w, int stride_h, int padding_w, int padding_h,
                           int input_channel_size, int filter_channel_size,
                           int output_filter_size,
                           int output_pixels, int grid_p, int grid_q,
                           unsigned int magic_pq, unsigned int shift_pq,
                           unsigned int magic_q, unsigned int shift_q,
                           unsigned int magic_s, unsigned int shift_s)

"""
    code = r"""
{
    __shared__ int2 lookup_table[FILTER_SIZE];
    __shared__ int lut_size;
    __shared__ Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X];
    __shared__ Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y];

    int lut_size_local = 0;

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, image_id;
    _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, image_id, output_pixel);
    image_id = (image_id * blockDim.x);

    //Zig zag along x axis to increase cache hits
    int temp_x, temp_y;
    _idiv_magic(output_pixel, magic_q, shift_q, Q, temp_y, temp_x);
    int output_pixel_x = (temp_y & 1) ? (Q - temp_x - 1) : temp_x;
    int output_pixel_y = temp_y;
    output_pixel = output_pixel_x + (output_pixel_y * Q);
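    //e.g. with Q = 4: row 0 visits x = 0,1,2,3 and row 1 visits x = 3,2,1,0,
    //so blocks with consecutive ids touch neighbouring columns.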

    int filter_id = blockIdx.y * blockDim.y;
    int tid = threadIdx.x + threadIdx.y * blockDim.x;

    //Offset buffers based on thread id
    I = &(I[image_id  + threadIdx.x]);
    F = &(F[filter_id + threadIdx.x]);

    %(filter_load_cond)s

    //Compute lookup table for filter/image data
%(lut_code)s

    if(tid == 0)
    {
        lut_size = lut_size_local;
    }

    __syncthreads();

    lut_size_local = lut_size;
    Matrix result[REG_TILE_Y] = {0};
    output_pixel = (output_pixel * N) >> 2;
    if(lut_size_local > 0)
    {
        //Evaluate gemm with outer product dimensions N, K and inner product CRS
        int CRS = lut_size_local * C;

        //Compute reciprocal for fast division by lut_size
        float reciprocal = 1.0f / (float)lut_size_local;

        //Initialize shared mem for first block
        int crs = CRS %% NUM_ROWS;
        crs = (crs == 0) ? NUM_ROWS : crs;
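        //The remainder of CRS is peeled off first so that every pass of the
        //main loop below consumes a full NUM_ROWS block of shared memory.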

        int c, rs;
        _idiv_fast(CRS - threadIdx.y - 1, lut_size_local, reciprocal, c, rs);

        int2 lut_entry = ((threadIdx.y & 7) >= crs) ? make_int2(0, 0) : lookup_table[rs];
        %(a_name)s_data[threadIdx.y][threadIdx.x].f4 =
            ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) :
            I[(c * input_channel_size)  + lut_entry.x].f4;
        %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s
            ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) :
            F[(c * filter_channel_size) + lut_entry.y].f4;

        //Iterate over entire filter
        for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS)
        {
            __syncthreads();

            #pragma unroll
            for(int i = 0; i < NUM_ROWS; i++)
            {
                Matrix load_row;
                Matrix load_col;

                load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4;
                load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                    {
                        result[q_offset].f[p_offset] += (load_row.f[p_offset] *
                                                         load_col.f[q_offset]);
                    }
                }
            }

            __syncthreads();

            //Load new image data and filter weights
            _idiv_fast(crs - threadIdx.y, lut_size_local, reciprocal, c, rs);

            lut_entry = lookup_table[rs];
            %(a_name)s_data[threadIdx.y][threadIdx.x].f4 =
                I[(c * input_channel_size)  + lut_entry.x].f4;
            %(b_name)s_data[threadIdx.y][threadIdx.x].f4 =
                %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4;
        }

        __syncthreads();

        //Accumulate product for last iteration
        #pragma unroll
        for(int i = 0; i < NUM_ROWS; i++)
        {
            Matrix load_row;
            Matrix load_col;

            load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4;
            load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4;

            //Accumulate product
            #pragma unroll
            for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
            {
                #pragma unroll
                for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                {
                    result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]);
                }
            }
        }
    }

    //Store result
    filter_id = (filter_id + threadIdx.y) << 2;
    if(filter_id < K)
    {
        image_id += threadIdx.x;

        #pragma unroll
        for(int q_offset = 0; q_offset < 4; q_offset++)
        {
            if(filter_id < K)
            {
                int out_index = (filter_id * output_filter_size) + output_pixel + image_id;
                %(bsum_code)s

                Matrix cur_value = {0};
                if(beta > 0.0f)
                {
                    cur_value.f4 = O[out_index].f4;
                }

                result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta);
                result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta);
                result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta);
                result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta);

                O[out_index].f4 = result[q_offset].f4;
            }
            filter_id++;
        }
    }
}
"""

    update_code = r"""
{
    __shared__ Matrix %(a_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4];
    __shared__ Matrix %(b_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4];

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, filter_id;
    _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, filter_id, output_pixel);
    filter_id = filter_id * TILE_DIM;
    int load_filter_id = filter_id + threadIdx.y;

    int filter_pixel_id = blockIdx.y * TILE_DIM;

    //TODO: Zig zag along x axis to increase cache hits
    int output_pixel_x, output_pixel_y;
    _idiv_magic(output_pixel, magic_q, shift_q, grid_q, output_pixel_y, output_pixel_x);

    //Compute input image and filter offsets for this pixel
    int c, rs;
    int crs = filter_pixel_id + threadIdx.y;
    _idiv_magic(crs, MAGIC_FILTER_SIZE, SHIFT_FILTER_SIZE, FILTER_SIZE, c, rs);

    int filter_x, filter_y;
    _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

    int output_pixel_x_save = output_pixel_x;
    for(; output_pixel_y < P; output_pixel_y += grid_p)
    {
        for(output_pixel_x = output_pixel_x_save; output_pixel_x < Q; output_pixel_x += grid_q)
        {
            int base_x = output_pixel_x * stride_w - padding_w + filter_x;
            int base_y = output_pixel_y * stride_h - padding_h + filter_y;
            int crs_in_bounds = (c < C) && (base_x >= 0) && (base_x < W) &&
                                (base_y >= 0) && (base_y < H);
            int input_pixel = W * base_y + base_x;
            output_pixel = output_pixel_x + (Q * output_pixel_y);

            //Pre-multiply offset to simplify indexing
            input_pixel = (input_pixel * N) >> 2;
            output_pixel = (output_pixel * N) >> 2;

            //Evaluate gemm with outer product dimensions N, K and inner product CRS
            Matrix result[ITEMS_PER_THREAD] = {0};

            //Load image data and transpose into shared mem
            //TODO: pad shared memory to avoid bank conflicts
            Matrix buffer;
            buffer.f4 = crs_in_bounds ?
                        I[(c * input_channel_size) + input_pixel + threadIdx.x].f4 :
                        make_float4(0, 0, 0, 0);
            %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0];
            %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1];
            %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2];
            %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3];

            //Load error data and transpose into shared mem
            buffer.f4 = (load_filter_id < K) ?
                        F[(load_filter_id * output_filter_size) + output_pixel + threadIdx.x].f4 :
                        make_float4(0, 0, 0, 0);
            %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0];
            %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1];
            %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2];
            %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3];
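
            //The scatter stores above transpose each loaded float4: vector
            //element j lands at shared column (8*j | threadIdx.y>>2),
            //sub-element (threadIdx.y & 3), matching the transposed read
            //pattern in the accumulation loop below.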

            //Iterate over entire minibatch
            for(int n = threadIdx.x + (TILE_DIM >> 2); n < (N >> 2); n += (TILE_DIM >> 2))
            {
                __syncthreads();

                #pragma unroll
                for(int i = 0; i < (TILE_DIM >> 2); i++)
                {
                    Matrix row_image;
                    Matrix row_error;

                    row_image.f4 =
                        %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4;
                    row_error.f4 =
                        %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4;

                    //Accumulate product
                    #pragma unroll
                    for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++)
                    {
                        #pragma unroll
                        for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++)
                        {
                            result[p_offset].f[q_offset] +=
                                (row_image.f[p_offset] * row_error.f[q_offset]);
                        }
                    }
                }

                __syncthreads();

                //Load image data and transpose into shared mem
                buffer.f4 = crs_in_bounds ?
                    I[(c * input_channel_size) + input_pixel + n].f4 :
                    make_float4(0, 0, 0, 0);
                %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[0];
                %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[1];
                %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[2];
                %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[3];

                //Load error data and transpose into shared mem
                buffer.f4 = (load_filter_id < K) ?
                    F[(load_filter_id * output_filter_size) + output_pixel + n].f4 :
                    make_float4(0, 0, 0, 0);
                %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[0];
                %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[1];
                %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[2];
                %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[3];
            }

            __syncthreads();

            //Accumulate product for last iteration
            #pragma unroll
            for(int i = 0; i < (TILE_DIM >> 2); i++)
            {
                Matrix row_image;
                Matrix row_error;

                row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4;
                row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++)
                    {
                        result[p_offset].f[q_offset] +=
                            (row_image.f[p_offset] * row_error.f[q_offset]);
                    }
                }
            }

            //Reduce result between threads in warp
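            //Each 8-lane group of the warp (selected by warp_y) holds a 4x4
            //partial tile; three rotate-and-__shfl steps pull a different
            //register from the lanes 8, 16 and 24 above, so each lane ends
            //up owning the fully reduced row picked by warp_y.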
            Matrix outbound;
            int warp_y = threadIdx.y & 3;
            int warp_id = threadIdx.x + (threadIdx.y << 3);
            buffer.f4 = (warp_y == 0) ? result[0].f4 :
                        (warp_y == 1) ? result[1].f4 :
                        (warp_y == 2) ? result[2].f4 :
                        result[3].f4;

            outbound.f4 = (warp_y == 0) ? result[3].f4 :
                          (warp_y == 1) ? result[0].f4 :
                          (warp_y == 2) ? result[1].f4 :
                          result[2].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 8);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 8);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 8);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 8);

            outbound.f4 = (warp_y == 0) ? result[2].f4 :
                          (warp_y == 1) ? result[3].f4 :
                          (warp_y == 2) ? result[0].f4 :
                          result[1].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 16);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 16);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 16);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 16);

            outbound.f4 = (warp_y == 0) ? result[1].f4 :
                          (warp_y == 1) ? result[2].f4 :
                          (warp_y == 2) ? result[3].f4 :
                          result[0].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 24);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 24);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 24);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 24);

            //Store result
            int idx_filter_id = filter_id + (threadIdx.x << 2);
            if(idx_filter_id < K && crs_in_bounds)
            {
                int out_index = (c * filter_channel_size) + (((rs * K) + (idx_filter_id)) >> 2);

                atomicAdd(&O[out_index].f[0], buffer.f[0]);
                atomicAdd(&O[out_index].f[1], buffer.f[1]);
                atomicAdd(&O[out_index].f[2], buffer.f[2]);
                atomicAdd(&O[out_index].f[3], buffer.f[3]);
            }
        }
    }
}
"""
    if operation == "update":
        code = header_code + update_code
    else:
        code = header_code + code

    magic = magic64(filter_size)

    code = code % {
        "filter_size": filter_size,
        "magic_filter_size": magic[0],
        "shift_filter_size": magic[1],
        "type": _ew_types[dtype]["type"],
        "lut_code": lut_code,
        "bsum_code": bsum_code,
        "operation": operation,
        "a_name": a_name,
        "b_name": b_name,
        "filter_load_cond": filter_load_cond,
        "check_filter_cond": check_filter_cond
    }

    options = ["--use_fast_math"]
    if debug and operation == "bprop":
        options = options + ["-g", "-G"]
    module = SourceModule(code, options=options)

    kernel = module.get_function("conv_" + operation)
    kernel.prepare("ffPPPPIIIIIIIIIIIIIIIIIIIIIIIIIIII")
    kernel.name = "conv_" + operation
    return kernel