Example #1
def _get_compound_kernel(type_args, compute_capability):
    """
    generate compound kernel for the optree from type_args
    """

    # from the stack, rebuild a mutable tree
    tree = _build_tree(type_args)
    # _print_tree(tree)
    # exit()

    # split all reductions and post reduction scalar operations out of the tree
    # sub-trees are converted to stacks and pushed onto stages list
    stages = _split_stages(tree)
    # _print_tree(tree)
    # exit()

    # set the final stage type to type of output (scalar or elementwise)
    last_stage = "red_out" if tree[1] == 1 else "ew_out"
    # convert the remainder of tree to stack
    stages.append((last_stage, _post_order(tree)))

    # for stage, stage_data in enumerate(stages):
    #     print stage_data[0], stage
    #     for s in stage_data[1]: print s
    #     print
    # exit()

    stack = list()
    placeholders = list()
    stage_out_reg = dict()
    arg_dict = dict()
    array_ids = set()
    fp16In = False
    rand_init = False
    rand_func = False
    threads = type_args[-1][3]
    template = _ew_template
    template_vals = {
        "threads": threads,
        "name": _get_kernel_name(),
        "common": list(),
        "inits": list(),
        "finish": list(),
    }

    for stage, stage_data in enumerate(stages):

        stage_type, stage_stack = stage_data
        new_placeholders = list()

        # build out the template as we process stages
        if stage_type == "reduction":

            new_placeholders.append("loads%d" % stage)
            new_placeholders.append("ops%d" % stage)
            new_placeholders.append("shfl_red%d" % stage)
            template += _stage_template["loop"].format(stage)
            if threads > 32:
                new_placeholders.append("var_red%d" % stage)
                new_placeholders.append("share1_red%d" % stage)
                new_placeholders.append("share2_red%d" % stage)
                template += _stage_template["red"].format(stage)
            else:
                template += _stage_template["red32"].format(stage)

        elif stage_type == "scalar":

            new_placeholders.append("ops%d" % stage)
            template += _stage_template["red_ops"].format(stage)

        elif stage_type == "red_out":

            new_placeholders.append("ops%d" % stage)
            template += _stage_template["red_out"].format(stage)

        else:  # ew_out

            new_placeholders.append("loads%d" % stage)
            new_placeholders.append("ops%d" % stage)
            template += _stage_template["loop"].format(stage)

        for key in new_placeholders:
            template_vals[key] = []
        placeholders.extend(new_placeholders)

        for arg_i, arg in enumerate(stage_stack):

            arg_type, arg_id = arg[0:2]

            # Array operands
            if arg_type is ng.GPUTensor:

                dtype, take_axis = arg[2:4]

                is_out_tensor = (stage == len(stages) - 1 and arg_i == 0)

                # first arg is output array, don't put on stack
                if is_out_tensor:
                    out_dtype = dtype
                    out_take = take_axis
                else:
                    stack.append("a%d" % arg_id)

                # 0: arg_id, 1: stage, 2: type, 3: cvt
                ew_dtype = _ew_types[dtype]
                fmt = (arg_id, stage, ew_dtype["type"], ew_dtype["cvt"])

                # First time we see a tensor initialize everything
                if arg_id not in array_ids:

                    array_ids.add(arg_id)
                    array_ids.add((arg_id, stage))

                    sig = "Pii"
                    if take_axis > 0:
                        sig += "P"

                    # output tensor
                    if is_out_tensor:
                        ew_out = _ew_strings["out%d" % take_axis]
                        arguments = ew_out["arguments"].format(*fmt)
                        template_vals["inits"].append(
                            ew_out["inits"].format(*fmt))
                    # input tensors
                    else:
                        ew_in = _ew_strings["in%d" % take_axis]
                        loads = "loads%d" % stage
                        arguments = ew_in["arguments"].format(*fmt)
                        template_vals["inits"].append(
                            ew_in["inits"].format(*fmt))
                        template_vals[loads].append(
                            ew_in["loads"].format(*fmt))

                    if dtype == 'f2' and not fp16In:
                        template_vals["common"].append(_common_fp16_to_fp32)
                        fp16In = True

                    arg_dict[arg] = (sig, arguments)

                # Subsequent times we see a tensor just initialize inits and
                # loads
                elif (arg_id, stage) not in array_ids:
                    array_ids.add((arg_id, stage))
                    ew_in = _ew_strings["in%d" % take_axis]
                    loads = "loads%d" % stage
                    template_vals["inits"].append(ew_in["inits"].format(*fmt))
                    template_vals[loads].append(ew_in["loads"].format(*fmt))

            # Constant operands
            elif arg_type is float:

                stack.append("c%d" % arg_id)
                if arg not in arg_dict:
                    arg_dict[arg] = (
                        "f", _ew_strings["const"]["arguments"].format(arg_id))

            # Operations (arg_type = op_name)
            else:

                if arg_type == "assign":

                    ops = "ops%d" % stage

                    # loop end condition for last stage
                    sig = "i"
                    arguments = ["const int n%d" % stage]

                    # rounding mode
                    if arg[2]:
                        mode = "random"
                        sig += "i"
                        arguments.append("const int mantissa_bits")
                        if not rand_init:
                            rand_init = _init_rand(template_vals)
                        template_vals["inits"].append(_init_rand_round_func)
                    else:
                        mode = "nearest"

                    arg_dict[arg] = (sig, ", ".join(arguments))

                    out_val = stack.pop()
                    # if the last stack value came from an argmax/min just do
                    # implicit type conversion
                    if out_val[0] == "i" and out_dtype[0] in "iu":
                        ew_round = None
                    else:
                        ew_round = _ew_strings["round"][
                            mode].get(out_dtype, None)
                        ew_common = _common_round[mode].get(out_dtype, None)
                        if ew_common:
                            template_vals["common"].append(ew_common)

                    if ew_round:
                        round_val = "r%d" % arg_id
                        template_vals[ops].append(
                            ew_round.format(round_val, out_val))
                    else:
                        round_val = out_val

                    template_vals[ops].append(
                        _ew_strings["out%d" % out_take]["output"].format(round_val))

                elif arg in stage_out_reg:

                    stack.append(stage_out_reg[arg])

                elif arg_type in _float_ops:

                    if len(template_vals["name"]) < 16:
                        template_vals["name"].append(arg_type)

                    ops = "ops%d" % stage

                    (num_ops, op_code) = _float_ops[arg_type]

                    if arg_type == "rand":
                        if not rand_init:
                            rand_init = _init_rand(template_vals)
                        if not rand_func:
                            template_vals["common"].append(_common_frand)
                            rand_func = True

                    op_list = ["r%d" % arg_id]

                    # build the operands from the stack
                    for i in range(num_ops):
                        op_list.append(stack.pop())

                    if arg_type == "onehot":

                        hot_axis = arg[2]
                        test_val = "i" if hot_axis else "bid"

                        ew_in = _ew_strings[arg_type + native_str(hot_axis)]
                        loads = "loads%d" % stage
                        template_vals["inits"].append(
                            ew_in["inits"].format(arg_id))
                        template_vals[loads].append(
                            ew_in["loads"].format(arg_id))
                        op_list.append("onehot%d" % arg_id)
                        op_list.append(test_val)

                        arg_dict[arg] = (
                            "P", ew_in["arguments"].format(arg_id))

                    template_vals[ops].append(op_code.format(*op_list))

                    # if this is the last op on the current stack, store its register
                    # in the stage output dict
                    if arg_i == len(stage_stack) - 1:
                        stage_out_reg[arg] = op_list[0]
                    # otherwise push the reg onto the stack as normal
                    else:
                        stack.append(op_list[0])

                elif arg_type in _reduction_ops:

                    if len(template_vals["name"]) < 16:
                        template_vals["name"].append(arg_type)

                    # loop end condition for current stage
                    # add regardless of duplicate reduction stage
                    arg_dict[arg] = ("i", "const int n%d" % stage)

                    # avoid float conversion for argmax/min
                    reg = "i" if "arg" == arg_type[0:3] else "r"

                    ops = "ops%d" % stage
                    shfl_red = "shfl_red%d" % stage
                    red_arg = "%s%d" % (reg, arg_id)
                    red_strings = _reduction_ops[arg_type]
                    stack_arg = stack.pop()

                    template_vals["inits"].append(
                        red_strings["inits"].format(red_arg))
                    template_vals[ops].append(
                        red_strings["ops"].format(red_arg, stack_arg))
                    template_vals[shfl_red].append(
                        red_strings["shfl_red"].format(red_arg))
                    if threads > 32:
                        var_red = "var_red%d" % stage
                        shr1_red = "share1_red%d" % stage
                        shr2_red = "share2_red%d" % stage
                        template_vals[var_red].append(red_arg)
                        template_vals[shr1_red].append(
                            red_strings["share1_red"].format(red_arg))
                        template_vals[shr2_red].append(
                            red_strings["share2_red"].format(red_arg))

                    # reduction ops are always the last on the stack
                    # just store the register in the stage output dict
                    stage_out_reg[arg] = red_arg

                else:
                    raise ValueError("Bad op type.")

    if (compute_capability[0] == 3 and compute_capability[1] < 5) or compute_capability[0] < 3:
        template_vals["common"].append(_common_kepler)

    template += _fin_template

    # since we reordered the operations we need to generate the argument list
    # in the original order
    sig = "P"
    arguments = list()
    unused = 1
    for arg in type_args:
        params = arg_dict.get(arg, False)
        if params:
            sig += params[0]
            arguments.append(params[1])
            del arg_dict[arg]
        # fill in the loop counter for the duplicate reductions that were
        # removed
        elif arg[0] in _reduction_ops:
            sig += "i"
            arguments.append("const int unused%d" % unused)
            unused += 1

    # convert lists to strings
    template_vals["name"] = "_".join(template_vals["name"])
    template_vals["common"] = "\n".join(template_vals["common"])
    template_vals["arguments"] = ",\n    ".join(arguments)
    template_vals["inits"] = "\n    ".join(template_vals["inits"])
    template_vals["finish"] = "\n".join(template_vals["finish"])

    # add the dynamic placeholders: loads#, ops#, reduction#
    for key in placeholders:
        template_vals[key] = "\n        ".join(template_vals[key])

    # populate the template
    code = template % template_vals

    # debugging:
    # print "Compiling %s" % template_vals["name"]
    # f = open("kernel.cu", "w")
    # f = open("%s.cu" % template_vals["name"], "w")
    # print >>f, code
    # f.close()

    # ,"-G" , keep=False
    # module = SourceModule(code, options=["--use_fast_math"])
    module = SourceModule(code, options=[])
    kernel = module.get_function(template_vals["name"])
    kernel.name = template_vals["name"]
    kernel.prepare(sig)

    return kernel
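All of these helpers follow the same PyCUDA pattern: render CUDA C into a string, compile it with SourceModule, fetch the function by name, and prepare() an argument signature so the kernel can be launched repeatedly with prepared_call. A minimal, self-contained sketch of that pattern (the scale kernel below is hypothetical and only illustrates the mechanics, assuming pycuda and a CUDA device are available):

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context)
import pycuda.driver as drv
from pycuda.compiler import SourceModule

_module = SourceModule(r"""
__global__ void scale(float* out, const float* in, float alpha, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = alpha * in[i];
}
""")
_kernel = _module.get_function("scale")
_kernel.prepare("PPfi")  # two pointers, a float and an int, like kernel.prepare(sig) above

x = np.arange(8, dtype=np.float32)
y = np.empty_like(x)
x_gpu = drv.mem_alloc(x.nbytes)
y_gpu = drv.mem_alloc(y.nbytes)
drv.memcpy_htod(x_gpu, x)
_kernel.prepared_call((1, 1), (32, 1, 1), y_gpu, x_gpu, 2.0, x.size)
drv.memcpy_dtoh(y, y_gpu)
print(y)  # x scaled by 2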
Example #2
def _get_compensated_sum_kernel(dtype, rounding):

    _compensated_sum = r"""

%(common)s

__global__ void compensated_sum(unsigned* rand_state,
          %(type)s* a_sum,
          %(type)s* a_cmp,
    const %(type)s* a_add,
    float cmp_scale, float add_scale,
    int row_strd, int col_strd, int n, int mantissa_bits)
{
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    int offset = bid * row_strd + tid * col_strd;
    int inc    = 32 * col_strd;

    a_sum += offset;
    a_cmp += offset;
    a_add += offset;

    %(inits)s

    for (int i = tid; i < n; i += 32)
    {
        float s32 = %(cvt)s(__ldg((const %(type)s*)a_sum));
        float c32 = %(cvt)s(__ldg((const %(type)s*)a_cmp));
        float a32 = %(cvt)s(__ldg(a_add));

        // Adjust amount to add by previous compensation
        float y32 = a32 * add_scale - c32 * cmp_scale;

        // Do the accumulation and truncate to the storage type
        float rnd_sum = s32 + y32;
        %(rnd_sum)s

        // Convert accumulation back to fp32 so we can do more math on it
        float t32 = %(cvt)s(t16);

        // recover the low order bits that were lost in the truncation
        float rnd_cmp = (t32 - s32) - y32;
        %(rnd_cmp)s

        *a_sum = t16;
        *a_cmp = c16;

        a_sum += inc;
        a_cmp += inc;
        a_add += inc;
    }
    %(finish)s
}
"""
    template_vals = dict()
    for key in ("common", "inits", "finish"):
        template_vals[key] = ""

    if dtype == "f2":
        template_vals["common"] += _common_fp16_to_fp32

    if rounding:
        template_vals["common"] += _common_urand_gen
        template_vals["common"] += _common_round["nearest"].get(dtype, "")
        template_vals["inits"] += _init_rand_func + _init_rand_round_func
        template_vals["finish"] += _finish_rand_func
        mode = "random"
    else:
        mode = "nearest"

    template_vals["common"] += _common_round[mode].get(dtype, "")

    template_vals["type"] = _ew_types[dtype]["type"]
    template_vals["cvt"] = _ew_types[dtype]["cvt"]

    no_op = "float {0} = {1};"

    rnd_sum = _ew_strings["round"][mode].get(dtype, no_op)
    rnd_cmp = _ew_strings["round"]["nearest"].get(dtype, no_op)

    template_vals["rnd_sum"] = rnd_sum.format("t16", "rnd_sum")
    template_vals["rnd_cmp"] = rnd_cmp.format("c16", "rnd_cmp")

    code = _compensated_sum % template_vals

    # f = open("compensated_sum.cu", "w")
    # print >>f, code
    # f.close()

    module = SourceModule(code)
    kernel = module.get_function("compensated_sum")
    kernel.prepare("PPPPffiiii")
    return kernel
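The kernel above performs a compensated (Kahan-style) accumulation: the low-order bits lost when the running sum is truncated to the storage type are kept in a_cmp and folded back into the next addend. A rough NumPy sketch of the per-element recurrence, assuming fp16 storage and cmp_scale = add_scale = 1:

import numpy as np

def compensated_step(a_sum, a_cmp, a_add):
    """One step of the compensated accumulation done per element above (fp16 storage)."""
    s32, c32, a32 = (v.astype(np.float32) for v in (a_sum, a_cmp, a_add))
    y32 = a32 - c32                                   # adjust addend by previous compensation
    t16 = (s32 + y32).astype(np.float16)              # accumulate, truncate to storage type
    c16 = ((t16.astype(np.float32) - s32) - y32).astype(np.float16)  # lost low-order bits
    return t16, c16

s = np.zeros(4, dtype=np.float16)
c = np.zeros(4, dtype=np.float16)
for _ in range(1000):
    s, c = compensated_step(s, c, np.full(4, 0.001, dtype=np.float16))
print(s)  # the compensation keeps precision that a naive fp16 running sum would lose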
Example #3
def _get_hist_kernel(dtype_str, nbins, offset):
    """
    Build a kernel to compute a 64 bin histogram.

    Use templating to generate a customized kernel depending on the input data type.

    Memoized to avoid compiling the same kernel twice.
    """
    type_str = _ew_types[dtype_str[1:]]
    from string import Template
    code = Template(_common_fp16_to_fp32 + r"""

#define MAX(a,b) (a > b ? a : b)
#define MIN(a,b) (a < b ? a : b)

__global__ void kernel_histo (
    int* d_hist, const $in_type* a1_in,
    int strides, int size)
{
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    __shared__ int s[$nbins];
    if(tid < $nbins){
        s[tid] = 0;
    }

    if(bid == 0 && tid < $nbins){
        d_hist[tid] = 0;
    }

    for (int i = tid + blockDim.x*bid; i < size; i += strides)
    {
        float a1 = $convert_to_float(__ldg(a1_in + i));

        float absval = fabs(a1);

        float logabs = round(log2f(absval));

        int bin = MIN($nbins-1, MAX(0, logabs-($offset)));

        atomicAdd(&s[bin], 1);

    }

    __syncthreads();

    if(tid < $nbins){
        atomicAdd(&d_hist[tid], s[tid]);
    }
}
""")

    module = SourceModule(code.substitute(in_type=type_str['type'],
                                          convert_to_float=type_str['cvt'],
                                          nbins=nbins,
                                          offset=offset),
                          options=[])
    kernel = module.get_function("kernel_histo")
    kernel.prepare("PPII")
    return kernel
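Per element, the kernel bins round(log2(|x|)) - offset, clipped to [0, nbins - 1], and accumulates the counts atomically. A small NumPy sketch of the same rule (the offset value below is only an illustrative assumption):

import numpy as np

def reference_histogram(a, nbins=64, offset=-48):
    """CPU sketch of kernel_histo's binning rule; the offset default is illustrative only."""
    logabs = np.round(np.log2(np.abs(a.astype(np.float32))))
    bins = np.clip(logabs - offset, 0, nbins - 1).astype(np.int64)
    return np.bincount(bins, minlength=nbins)

print(reference_histogram(np.random.randn(10000).astype(np.float16)))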
Example #4
def _get_sorting_kernel(kernel_id, block_size):
    """
    Builds kernels used for sorting inputs. There are several kernels here
    corresponding to the steps in the algorithm. The algorithm works by
    determining the sorted position for each input item. This is done with
    a bucket sort algorithm, where each word_id is a bucket. The first step
    determines the size of each bucket (number of occurrences of each word_id).
    Next, a prefix sum is computed over the list of bucket sizes to find
    where each bucket will be placed in the output buffer. Finally, each thread
    places its index into the correct sorted position based on the bucket
    start index (computed from the prefix sum) and that thread's offset into
    the bucket (which is taken from the output of the atomic add done in the
    first step).

    Arguments:
        kernel_id (Integer): Which step to build the kernel for [0, 4]
        block_size (Integer): Number of threads per block for the prefix sum
            kernels.
    """
    code = r"""
#define THREADS %(threads)s
#define STORE_BLOCKSUM %(store_blocksum)s
__global__ void sort_inputs0(
        int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size,
        const int input_length)
{
    const int tid = threadIdx.x + (blockDim.x * blockIdx.x);
    int word_id;

    if(tid < input_length)
    {
        word_id = inputs[tid];
        offset_buffer[tid] = atomicAdd(&word_counts[word_id], 1);
    }
}

__device__ void scan(int* buffer, int* blocksum, int global_length)
{
    const int tid = (threadIdx.x << 1) + 1;
    const int gid = ((threadIdx.x + (blockIdx.x * blockDim.x)) << 1) + 1;

    __shared__ int local_counts[THREADS * 2];
    local_counts[tid] = buffer[gid];
    local_counts[tid - 1] = buffer[gid - 1];

    #pragma unroll
    for(int skip = 1; skip <= THREADS; skip <<= 1)
    {
        int mask = (skip << 1) - 1;
        if((tid & mask) == mask)
        {
            local_counts[tid] += local_counts[tid - skip];
        }

        __syncthreads();
    }

    if(tid == (THREADS * 2 - 1))
    {
#if STORE_BLOCKSUM
        blocksum[blockIdx.x] = local_counts[tid];
#endif
        local_counts[tid] = 0;
    }

    #pragma unroll
    for(int skip = THREADS; skip > 0; skip >>= 1)
    {
        int mask = (skip << 1) - 1;
        if((tid & mask) == mask)
        {
            int temp = local_counts[tid - skip];
            local_counts[tid - skip] = local_counts[tid];
            local_counts[tid] += temp;
        }

        __syncthreads();
    }

    if(gid < global_length)
    {
        buffer[gid] = local_counts[tid];
        buffer[gid - 1] = local_counts[tid - 1];
    }
}

__global__ void sort_inputs1(
        int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size,
        const int input_length)
{
    scan(word_counts, word_counts + vocab_size, vocab_size);
}

__global__ void sort_inputs2(
        int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size,
        const int input_length)
{
    scan(word_counts + vocab_size, 0, blockDim.x);
}

__global__ void sort_inputs3(
        int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size,
        const int input_length)
{
    const int gid = (threadIdx.x + (blockIdx.x * blockDim.x)) << 1;

    if(gid < vocab_size)
    {
        word_counts[gid] += word_counts[vocab_size + blockIdx.x];
        word_counts[gid + 1] += word_counts[vocab_size + blockIdx.x];
    }
}

__global__ void sort_inputs4(
        int* inputs, int* index_buffer, int* offset_buffer, int* word_counts, const int vocab_size,
        const int input_length)
{
    const int tid = threadIdx.x + (blockDim.x * blockIdx.x);
    int word_id;

    if(tid < input_length)
    {
        word_id = inputs[tid];
        int sorted_position = word_counts[word_id] + offset_buffer[tid];
        index_buffer[sorted_position] = tid;
    }
}
"""
    code = code % {
        "threads": block_size,
        "store_blocksum": (1 if kernel_id == 1 else 0)
    }
    module = SourceModule(code, options=["--use_fast_math"])

    function_name = "sort_inputs" + native_str(kernel_id)
    kernel = module.get_function(function_name)
    kernel.prepare("PPPPII")
    kernel.name = "sort_inputs"
    return kernel
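A CPU sketch of the three logical steps described in the docstring (bucket counts with per-item offsets, an exclusive prefix sum over the counts, then the final scatter); it is a reference for the semantics only, not the blocked GPU implementation:

import numpy as np

def reference_sort_positions(inputs, vocab_size):
    """CPU reference for the bucket sort described in the docstring above."""
    # Step 1: bucket sizes and each item's offset within its bucket
    # (tie order is arbitrary on the GPU; here it is order of appearance).
    counts = np.zeros(vocab_size, dtype=np.int64)
    offsets = np.empty(len(inputs), dtype=np.int64)
    for i, w in enumerate(inputs):
        offsets[i] = counts[w]
        counts[w] += 1
    # Step 2: exclusive prefix sum gives the start of each bucket in the output.
    starts = np.concatenate(([0], np.cumsum(counts)[:-1]))
    # Step 3: scatter each input index to its sorted position.
    index_buffer = np.empty(len(inputs), dtype=np.int64)
    for i, w in enumerate(inputs):
        index_buffer[starts[w] + offsets[i]] = i
    return index_buffer

words = np.array([3, 1, 3, 0, 1, 3])
print(reference_sort_positions(words, vocab_size=4))  # indices grouped by word_id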
Example #5
def _get_fprop_roipooling(clss):

    code = r"""
#define FLT_MAX 3.402823466E+38F

__global__ void fprop_roipooling(const int nthreads,
    const int num_rois, const int img_count,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width,
    const float* bottom_data, const float* bottom_rois, float* top_data,
    int* argmax_data, const float spatial_scale) {
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; \
        index < (nthreads); index += blockDim.x * gridDim.x){
        // (c, ph, pw, n) is an element in the pooled output
        int n = index % num_rois;
        int pw = (index / num_rois) % pooled_width;
        int ph = (index / num_rois / pooled_width) % pooled_height;
        int c = index / num_rois / pooled_width / pooled_height;

        bottom_rois += n * 5;
        int roi_batch_ind = bottom_rois[0];
        int roi_start_w = round(bottom_rois[1] * spatial_scale);
        int roi_start_h = round(bottom_rois[2] * spatial_scale);
        int roi_end_w = round(bottom_rois[3] * spatial_scale);
        int roi_end_h = round(bottom_rois[4] * spatial_scale);

        // Force malformed ROIs to be 1x1
        int roi_width = max(roi_end_w - roi_start_w + 1, 1);
        int roi_height = max(roi_end_h - roi_start_h + 1, 1);
        float bin_size_h = static_cast<float>(roi_height)
                           / static_cast<float>(pooled_height);
        float bin_size_w = static_cast<float>(roi_width)
                           / static_cast<float>(pooled_width);

        int hstart = static_cast<int>(floor(static_cast<float>(ph)
                                            * bin_size_h));
        int wstart = static_cast<int>(floor(static_cast<float>(pw)
                                            * bin_size_w));
        int hend = static_cast<int>(ceil(static_cast<float>(ph + 1)
                                         * bin_size_h));
        int wend = static_cast<int>(ceil(static_cast<float>(pw + 1)
                                         * bin_size_w));

        // Add roi offsets and clip to input boundaries
        hstart = min(max(hstart + roi_start_h, 0), height);
        hend = min(max(hend + roi_start_h, 0), height);
        wstart = min(max(wstart + roi_start_w, 0), width);
        wend = min(max(wend + roi_start_w, 0), width);
        bool is_empty = (hend <= hstart) || (wend <= wstart);

        // Define an empty pooling region to be zero
        float maxval = is_empty ? 0 : -FLT_MAX;
        // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
        int maxidx = -1;

        bottom_data += c * height * width * img_count;

        for (int h = hstart; h < hend; ++h) {
          for (int w = wstart; w < wend; ++w) {
            int bottom_index = h * width * img_count + w * img_count + roi_batch_ind;
            if (bottom_data[bottom_index] > maxval) {
              maxval = bottom_data[bottom_index];
              maxidx = bottom_index;
            }
          }
        }
        top_data[index] = maxval;
        argmax_data[index] = maxidx;
        // Notice the maxidx (from bottom_index) is relative to the dimension
        // (h, w, img_count) of the feature map, so max value is HWN
    }
}

"""

    module = SourceModule(code)
    kernel = module.get_function("fprop_roipooling")
    sig = "8I 4P 1f"
    kernel.prepare(sig)
    return kernel
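For reference, a NumPy version of the same pooling math, assuming the (C, H, W, N) feature layout implied by bottom_index and ROI rows of [batch_index, x1, y1, x2, y2]; it mirrors the bin-boundary and clipping logic but is only an illustrative sketch, not the library API:

import numpy as np

def reference_roipooling_fprop(feat, rois, pooled_h, pooled_w, spatial_scale):
    """CPU sketch of fprop_roipooling; feat is (C, H, W, N), rois is (num_rois, 5)."""
    C, H, W, N = feat.shape
    out = np.zeros((C, pooled_h, pooled_w, len(rois)), dtype=feat.dtype)
    for n, (ind, x1, y1, x2, y2) in enumerate(rois):
        ind = int(ind)
        ws, hs = int(round(x1 * spatial_scale)), int(round(y1 * spatial_scale))
        we, he = int(round(x2 * spatial_scale)), int(round(y2 * spatial_scale))
        rw, rh = max(we - ws + 1, 1), max(he - hs + 1, 1)  # force malformed ROIs to 1x1
        for ph in range(pooled_h):
            for pw in range(pooled_w):
                h0 = min(max(int(np.floor(ph * rh / pooled_h)) + hs, 0), H)
                h1 = min(max(int(np.ceil((ph + 1) * rh / pooled_h)) + hs, 0), H)
                w0 = min(max(int(np.floor(pw * rw / pooled_w)) + ws, 0), W)
                w1 = min(max(int(np.ceil((pw + 1) * rw / pooled_w)) + ws, 0), W)
                region = feat[:, h0:h1, w0:w1, ind]
                if region.size:                            # empty windows stay 0
                    out[:, ph, pw, n] = region.reshape(C, -1).max(axis=1)
    return out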
Example #6
def _get_bprop_roipooling(clss):

    code = r"""
__global__ void bprop_roipooling(const int nthreads,
    const int num_rois, const int img_count,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width,
    const float* top_diff, const float* bottom_rois, float* bottom_diff,
    const int* argmax_data, const float spatial_scale) {
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; \
        index < (nthreads); index += blockDim.x * gridDim.x){
        // (c, h, w, n) coords in bottom data on feature map
        int n = index % img_count;
        int w = (index / img_count) % width;
        int h = (index / img_count / width) % height;
        int c = index / img_count/ width / height;

        float gradient = 0;
        // Accumulate gradient over all ROIs that pooled this element
        for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
          const float* offset_bottom_rois = bottom_rois + roi_n * 5;
          int roi_batch_ind = offset_bottom_rois[0];
          // Skip if ROI's batch index doesn't match n
          if (n != roi_batch_ind) {
            continue;
          }

          int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
          int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
          int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
          int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);

          // Skip if ROI doesn't include (h, w)
          const bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
                               h >= roi_start_h && h <= roi_end_h);
          if (!in_roi) {
            continue;
          }

          int offset = c * pooled_height * pooled_width * num_rois;
          const float* offset_top_diff = top_diff + offset;
          const int* offset_argmax_data = argmax_data + offset;

          // Compute feasible set of pooled units that could have pooled
          // this bottom unit

          // Force malformed ROIs to be 1x1
          int roi_width = max(roi_end_w - roi_start_w + 1, 1);
          int roi_height = max(roi_end_h - roi_start_h + 1, 1);

          float bin_size_h = static_cast<float>(roi_height)
                             / static_cast<float>(pooled_height);
          float bin_size_w = static_cast<float>(roi_width)
                             / static_cast<float>(pooled_width);

          int phstart = floor(static_cast<float>(h - roi_start_h) / bin_size_h);
          int phend = ceil(static_cast<float>(h - roi_start_h + 1) / bin_size_h);
          int pwstart = floor(static_cast<float>(w - roi_start_w) / bin_size_w);
          int pwend = ceil(static_cast<float>(w - roi_start_w + 1) / bin_size_w);

          phstart = min(max(phstart, 0), pooled_height);
          phend = min(max(phend, 0), pooled_height);
          pwstart = min(max(pwstart, 0), pooled_width);
          pwend = min(max(pwend, 0), pooled_width);

          for (int ph = phstart; ph < phend; ++ph) {
            for (int pw = pwstart; pw < pwend; ++pw) {
              int top_index = ph * pooled_width * num_rois + pw * num_rois + roi_n;
              int bottom_index = h * width * img_count + w * img_count + roi_batch_ind;
              if (offset_argmax_data[top_index] == bottom_index) {
                gradient += offset_top_diff[top_index];
              }
            }
          }
        }
        bottom_diff[index] = gradient;
    }
}

"""

    module = SourceModule(code)
    kernel = module.get_function("bprop_roipooling")
    sig = "8I 4P 1f"
    kernel.prepare(sig)
    return kernel
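The backward kernel routes each pooled output's gradient to the one input element whose index was recorded in argmax_data during fprop (an index relative to the (H, W, N) extent of a single channel; -1 marks an empty window). A compact NumPy sketch of that scatter-add, with hypothetical array shapes:

import numpy as np

def reference_roipooling_bprop(top_diff, argmax_data, bottom_size):
    """CPU sketch of the gradient routing above. top_diff and argmax_data have one row per
    channel; argmax_data holds the (h * W * N + w * N + n) index saved during fprop, or -1."""
    C = top_diff.shape[0]
    bottom_diff = np.zeros((C, bottom_size), dtype=top_diff.dtype)
    for c in range(C):
        for idx, grad in zip(argmax_data[c].ravel(), top_diff[c].ravel()):
            if idx >= 0:
                bottom_diff[c, idx] += grad
    return bottom_diff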
Example #7
def _get_bn_fprop_kernel(dtype, threads, compute_capability):

    if threads > 32:
        shr_code = "__shared__ float sPartials[THREADS];"
        red_code = r"""
    sPartials[tid] = xvar;
    __syncthreads();

    #pragma unroll
    for (int a = THREADS >> 1; a > 32; a >>= 1)
    {
        if ( tid < a )
            sPartials[tid] += sPartials[tid + a];
        __syncthreads();
    }
    if ( tid < 32 )
    {
        xvar = sPartials[tid] + sPartials[tid + 32];
        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
            xvar += __shfl_xor(xvar, i);

        sPartials[tid] = xvar * rcpN;
    }
    __syncthreads();
    xvar = sPartials[0];
"""
    else:
        shr_code = ""
        red_code = r"""
    #pragma unroll
    for (int i = 16; i > 0; i >>= 1)
        xvar += __shfl_xor(xvar, i);
    xvar *= rcpN;
"""

    code = r"""
#define THREADS %(threads)s

%(common)s
%(binary)s

__global__ void batchnorm_fprop (
    %(type)s* y_out, float* xvar_out, float* gmean_out, float* gvar_out,
    const %(type)s* x_in, const float* xsum_in, const float* gmean_in,
    const float* gvar_in, const float* gamma_in, const float* beta_in,
    const float eps, const float rho, const float accumbeta, const int N,
    const int relu, bool binary)
{
    %(share)s

    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;
    int offset = bid * N;

    const %(type)s* x_in0 = x_in + offset + tid;

    const float rcpN = 1.0f/(float)N;

    float xmean = __ldg(xsum_in + bid) * rcpN;

    float xvar = 0.0f;
    for (int i = tid; i < N; i += THREADS)
    {
        float x = %(cvt)s(__ldg(x_in0));
        x_in0 += THREADS;

        x -= xmean;
        if (binary) {
            xvar += shift_element(x, x, true);
        } else {
            xvar += x * x;
        }
    }
    %(red)s

    float gamma = __ldg(gamma_in + bid);
    float beta  = __ldg(beta_in  + bid);

    if ( tid == 0 )
    {
        float gmean = __ldg(gmean_in + bid);
        float gvar  = __ldg(gvar_in  + bid);

        *(xvar_out  + bid) = xvar;
        *(gmean_out + bid) = gmean * rho + (1.0f - rho) * xmean;
        *(gvar_out  + bid) = gvar  * rho + (1.0f - rho) * xvar;
    }

    float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps);

    int start = N - (THREADS*4 - tid);
    offset += start;
    x_in   += offset;
    y_out  += offset;

    for (int i = start; i >= -THREADS*3; i -= THREADS*4)
    {
        float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in + THREADS*0)) : 0.0f;
        float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in + THREADS*1)) : 0.0f;
        float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in + THREADS*2)) : 0.0f;
        float x3 =                   %(cvt)s(__ldg(x_in + THREADS*3));

        x_in -= THREADS*4;

        float xhat0 = 0.0f;
        float xhat1 = 0.0f;
        float xhat2 = 0.0f;
        float xhat3 = 0.0f;

        float y0 = 0.0f;
        float y1 = 0.0f;
        float y2 = 0.0f;
        float y3 = 0.0f;
        if (binary) {
            xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true);
            xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true);
            xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true);
            xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true);

            y0 = shift_element(xhat0, gamma, true) + beta;
            y1 = shift_element(xhat1, gamma, true) + beta;
            y2 = shift_element(xhat2, gamma, true) + beta;
            y3 = shift_element(xhat3, gamma, true) + beta;
        } else {
            xhat0 = (x0 - xmean) * xvar_rcp_sqrt;
            xhat1 = (x1 - xmean) * xvar_rcp_sqrt;
            xhat2 = (x2 - xmean) * xvar_rcp_sqrt;
            xhat3 = (x3 - xmean) * xvar_rcp_sqrt;

            y0 = xhat0 * gamma + beta;
            y1 = xhat1 * gamma + beta;
            y2 = xhat2 * gamma + beta;
            y3 = xhat3 * gamma + beta;
        }

        if (relu)
        {
            y0 = fmaxf(y0, 0.0f);
            y1 = fmaxf(y1, 0.0f);
            y2 = fmaxf(y2, 0.0f);
            y3 = fmaxf(y3, 0.0f);
        }

        %(y0_out)s
        %(y1_out)s
        %(y2_out)s
        %(y3_out)s
        if (accumbeta == 0.0)
        {
            if (i >= -THREADS*0) *(y_out + THREADS*0) = y0_val;
            if (i >= -THREADS*1) *(y_out + THREADS*1) = y1_val;
            if (i >= -THREADS*2) *(y_out + THREADS*2) = y2_val;
                                 *(y_out + THREADS*3) = y3_val;
        }
        else
        {
            if (i >= -THREADS*0) *(y_out + THREADS*0) = y_out[THREADS*0] * accumbeta + y0_val;
            if (i >= -THREADS*1) *(y_out + THREADS*1) = y_out[THREADS*1] * accumbeta + y1_val;
            if (i >= -THREADS*2) *(y_out + THREADS*2) = y_out[THREADS*2] * accumbeta + y2_val;
                                 *(y_out + THREADS*3) = y_out[THREADS*3] * accumbeta + y3_val;
        }
        y_out -= THREADS*4;
    }
}
"""
    out_code = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};")
    common_code  = _common_round["nearest"].get(dtype, "")
    if dtype == "f2":
        common_code += _common_fp16_to_fp32

    if (compute_capability[0] == 3 and compute_capability[1] < 5) or compute_capability[0] < 3:
        common_code += _common_kepler

    code = code % {
        "common"    : common_code,
        "binary"    : shift_element(),
        "share"     : shr_code,
        "red"       : red_code,
        "threads"   : threads,
        "type"      : _ew_types[dtype]["type"],
        "cvt"       : _ew_types[dtype]["cvt"],
        "y0_out"    : out_code.format("y0_val",     "y0"),
        "y1_out"    : out_code.format("y1_val",     "y1"),
        "y2_out"    : out_code.format("y2_val",     "y2"),
        "y3_out"    : out_code.format("y3_val",     "y3"),
    }
    module = SourceModule(code, options=["--use_fast_math"])
    kernel = module.get_function("batchnorm_fprop")
    kernel.prepare("PPPPPPPPPPfffIII")
    kernel.name = "batchnorm_fprop"
    return kernel
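Numerically, batchnorm_fprop computes per-row mean and variance, updates the running global statistics with rho, and applies the affine normalization in a single pass. A NumPy sketch of the non-binary path (relu off, accumbeta == 0), treating each row of x as one normalization group of length N; the eps and rho defaults below are illustrative:

import numpy as np

def reference_bn_fprop(x, gamma, beta, gmean, gvar, eps=1e-3, rho=0.9):
    """CPU sketch of batchnorm_fprop's math (non-binary path, relu off, accumbeta == 0)."""
    xmean = x.mean(axis=1, keepdims=True)
    xvar = ((x - xmean) ** 2).mean(axis=1)
    gmean_out = gmean * rho + (1.0 - rho) * xmean.ravel()
    gvar_out = gvar * rho + (1.0 - rho) * xvar
    y = (x - xmean) / np.sqrt(xvar + eps)[:, None] * gamma[:, None] + beta[:, None]
    return y, xvar, gmean_out, gvar_out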
Example #8
def _get_conv_kernel(dtype, filter_size, bsum, operation, filter_bounds_check=False, debug=False):
    """
    Builds the convolution kernel for a specified filter size.

    Arguments:
        dtype (np.dtype): The data type which the kernel will operate on.
        filter_size (int): Total number of elements per filter (R * S)
        bsum (boolean): If set to true, the kernel will include code to compute
            the batch sum during fprop.
        operation (string): Determines which kernel to build. Options follow:
            'fprop': Forward propagation of activations.
            'bprop': Backward propagation of error.
            'update': Computes gradients for filter weights based on error and inputs.
        filter_bounds_check (boolean): Checks if filter weight is in bounds when K is
            not a multiple of 32.
        debug (boolean): When set to true, kernels will be compiled with debug symbols.
    """
    assert operation in ["fprop", "bprop", "update"]
    if operation == "fprop" or operation == "update":
        lut_code = r"""
    if(tid < 32)
    {
        int rs = tid;
        int base_x, base_y;

        base_x = output_pixel_x * stride_w - padding_w;
        base_y = output_pixel_y * stride_h - padding_h;

        unsigned int mask = (1 << tid) - 1;

        while(rs < FILTER_SIZE)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

            int index_x = base_x + filter_x * dilation_w;
            int index_y = base_y + filter_y * dilation_h;

            //Check if the index is valid
            int in_bounds = (index_x >= 0 && index_x < W && index_y >= 0 && index_y < H);
            unsigned int threads_in_bounds = __ballot(in_bounds);

            //Store lookup table entry
            if(in_bounds)
            {
                int2 lut_entry;
                lut_entry.x = ((index_y * W + index_x) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;

                int index = lut_size_local + __popc(threads_in_bounds & mask);
                lookup_table[index] = lut_entry;
            }

            lut_size_local += __popc(threads_in_bounds);

            rs += 32;
        }
    }
"""
    elif operation == "bprop":
        lut_code = r"""
    if(tid < 32)
    {
        int rs = tid;
        int base_q, base_p;

        base_q = output_pixel_x - ((S - 1) * dilation_w - padding_w);
        base_p = output_pixel_y - ((R - 1) * dilation_h - padding_h);

        unsigned int mask = (1 << tid) - 1;

        while(rs < FILTER_SIZE)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

            int index_q = base_q + filter_x * dilation_w;
            int index_p = base_p + filter_y * dilation_h;

            //Check if the index is valid
            int in_bounds = (((index_q % stride_w) | (index_p % stride_h)) == 0);
            index_q /= stride_w;
            index_p /= stride_h;
            in_bounds = in_bounds && (index_q >= 0 && index_q < W
                                      && index_p >= 0 && index_p < H);
            unsigned int threads_in_bounds = __ballot(in_bounds);

            //Store lookup table entry
            if(in_bounds)
            {
                int2 lut_entry;
                lut_entry.x = (((index_p * W) + index_q) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;

                int index = lut_size_local + __popc(threads_in_bounds & mask);
                lookup_table[index] = lut_entry;
            }

            lut_size_local += __popc(threads_in_bounds);

            rs += 32;
        }
    }
"""
    if bsum:
        bsum_code = r"""
            float local_bsum = result[q_offset].f[0] + result[q_offset].f[1] +
                               result[q_offset].f[2] + result[q_offset].f[3];
            atomicAdd(&bsum[filter_id], local_bsum);
"""
    else:
        bsum_code = ""

    if operation == "update":
        a_name = "image"
        b_name = "error"
    else:
        if operation == "fprop":
            a_name = "image"
            b_name = "filter"
        elif operation == "bprop":
            a_name = "error"
            b_name = "filter"

    if filter_bounds_check:
        filter_load_cond = "int filter_load_in_bounds = (((filter_id + threadIdx.x) << 2) < K);"
        check_filter_cond = "(!filter_load_in_bounds) ? make_float4(0, 0, 0, 0) :"
    else:
        filter_load_cond = ""
        check_filter_cond = ""

    header_code = r"""
#define TILE_DIM            32
#define ITEMS_PER_THREAD    4
#define THREADS_DIM         8

#define REG_TILE_X          4
#define REG_TILE_Y          4
#define THREADS_DIM_X       8
#define THREADS_DIM_Y       8
#define SM_TILE_X           (REG_TILE_X * THREADS_DIM_X)
#define SM_TILE_Y           (REG_TILE_Y * THREADS_DIM_Y)

#define NUM_ROWS            8
#define FILTER_SIZE         %(filter_size)s
#define MAGIC_FILTER_SIZE   %(magic_filter_size)s
#define SHIFT_FILTER_SIZE   %(shift_filter_size)s

typedef union Matrix {
    %(type)s4 f4;
    %(type)s f[4];
} Matrix;

__device__ inline void _idiv_fast(int numerator, int denominator, float rcp,
                                 int& result, int& remainder)
{
    result = (int)((float)numerator * rcp);
    remainder = numerator - (result * denominator);
    result = (remainder >= denominator) ? (result + 1) : result;
    remainder = (remainder >= denominator) ? (remainder - denominator) : remainder;
}

__device__ inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift,
                                   int denominator, int& result, int& remainder)
{
    if(magic == 1)
    {
        result = numerator >> shift;
    }
    else
    {
        unsigned long long res64 = numerator * (unsigned long long)magic;
        result = ((int)(res64 >> 32) >> shift);
    }
    remainder = numerator - (result * denominator);
}

__device__ inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift,
                                     int denominator, int& result, int& remainder)
{
    if(magic == 1)
    {
        result = numerator >> shift;
    }
    else
    {
        result = ((numerator * magic) >> shift);
    }
    remainder = numerator - (result * denominator);
}

//Note: N and K must be multiples of 4
//blockIdx.x is gemm tile id (K dimension) and output pixel id
//blockIdx.y is gemm tile id (N dimension)
//threadIdx.x is gemm tile offset (K dimension)
//threadIdx.y is gemm tile offset (N dimension)
__global__ void conv_%(operation)s(
                           %(type)s alpha, %(type)s beta,
                           Matrix *I, Matrix *F, Matrix *O, float* bsum,
                           int C, int D, int H, int W, int N,
                           int T, int R, int S, int K,
                           int M, int P, int Q,
                           int stride_w, int stride_h, int padding_w, int padding_h,
                           int dilation_w, int dilation_h,
                           int input_channel_size, int filter_channel_size,
                           int output_filter_size,
                           int output_pixels, int grid_p, int grid_q,
                           unsigned int magic_pq, unsigned int shift_pq,
                           unsigned int magic_q, unsigned int shift_q,
                           unsigned int magic_s, unsigned int shift_s)

"""
    code = r"""
{
    __shared__ int2 lookup_table[FILTER_SIZE];
    __shared__ int lut_size;
    __shared__ Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X];
    __shared__ Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y];

    int lut_size_local = 0;

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, image_id;
    _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, image_id, output_pixel);
    image_id = (image_id * blockDim.x);

    //Zig zag along x axis to increase cache hits
    int temp_x, temp_y;
    _idiv_magic(output_pixel, magic_q, shift_q, Q, temp_y, temp_x);
    int output_pixel_x = (temp_y & 1) ? (Q - temp_x - 1) : temp_x;
    int output_pixel_y = temp_y;
    output_pixel = output_pixel_x + (output_pixel_y * Q);

    int filter_id = blockIdx.y * blockDim.y;
    int tid = threadIdx.x + threadIdx.y * blockDim.x;

    //Offset buffers based on thread id
    I = &(I[image_id  + threadIdx.x]);
    F = &(F[filter_id + threadIdx.x]);

    %(filter_load_cond)s

    //Compute lookup table for filter/image data
%(lut_code)s

    if(tid == 0)
    {
        lut_size = lut_size_local;
    }

    __syncthreads();

    lut_size_local = lut_size;
    Matrix result[REG_TILE_Y] = {0};
    output_pixel = (output_pixel * N) >> 2;
    if(lut_size_local > 0)
    {
        //Evaluate gemm with outer product dimensions N, K and inner product CRS
        int CRS = lut_size_local * C;

        //Compute magic numbers for division by lut_size
        float reciprocal = 1.0f / (float)lut_size_local;

        //Initialize shared mem for first block
        int crs = CRS %% NUM_ROWS;
        crs = (crs == 0) ? 8 : crs;

        int c, rs;
        _idiv_fast(CRS - threadIdx.y - 1, lut_size_local, reciprocal, c, rs);

        int2 lut_entry = ((threadIdx.y & 7) >= crs) ? make_int2(0, 0) : lookup_table[rs];
        %(a_name)s_data[threadIdx.y][threadIdx.x].f4 =
            ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) :
            I[(c * input_channel_size)  + lut_entry.x].f4;
        %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s
            ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) :
            F[(c * filter_channel_size) + lut_entry.y].f4;

        //Iterate over entire filter
        for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS)
        {
            __syncthreads();

            #pragma unroll
            for(int i = 0; i < NUM_ROWS; i++)
            {
                Matrix load_row;
                Matrix load_col;

                load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4;
                load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                    {
                        result[q_offset].f[p_offset] += (load_row.f[p_offset] *
                                                         load_col.f[q_offset]);
                    }
                }
            }

            __syncthreads();

            //Load new image data and filter weights
            _idiv_fast(crs - threadIdx.y, lut_size_local, reciprocal, c, rs);

            lut_entry = lookup_table[rs];
            %(a_name)s_data[threadIdx.y][threadIdx.x].f4 =
                I[(c * input_channel_size)  + lut_entry.x].f4;
            %(b_name)s_data[threadIdx.y][threadIdx.x].f4 =
                %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4;
        }

        __syncthreads();

        //Accumulate product for last iteration
        #pragma unroll
        for(int i = 0; i < NUM_ROWS; i++)
        {
            Matrix load_row;
            Matrix load_col;

            load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4;
            load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4;

            //Accumulate product
            #pragma unroll
            for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
            {
                #pragma unroll
                for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                {
                    result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]);
                }
            }
        }
    }

    //Store result
    filter_id = (filter_id + threadIdx.y) << 2;
    if(filter_id < K)
    {
        image_id += threadIdx.x;

        #pragma unroll
        for(int q_offset = 0; q_offset < 4; q_offset++)
        {
            if(filter_id < K)
            {
                int out_index = (filter_id * output_filter_size) + output_pixel + image_id;
                %(bsum_code)s

                Matrix cur_value = {0};
                if(beta > 0.0f)
                {
                    cur_value.f4 = O[out_index].f4;
                }

                result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta);
                result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta);
                result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta);
                result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta);

                O[out_index].f4 = result[q_offset].f4;
            }
            filter_id++;
        }
    }
}
"""

    update_code = r"""
{
    __shared__ Matrix %(a_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4];
    __shared__ Matrix %(b_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4];

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, filter_id;
    _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, filter_id, output_pixel);
    filter_id = filter_id * TILE_DIM;
    int load_filter_id = filter_id + threadIdx.y;

    int filter_pixel_id = blockIdx.y * TILE_DIM;

    //TODO: Zig zag along x axis to increase cache hits
    int output_pixel_x, output_pixel_y;
    _idiv_magic(output_pixel, magic_q, shift_q, grid_q, output_pixel_y, output_pixel_x);

    //Compute input image and filter offsets for this pixel
    int c, rs;
    int crs = filter_pixel_id + threadIdx.y;
    _idiv_magic(crs, MAGIC_FILTER_SIZE, SHIFT_FILTER_SIZE, FILTER_SIZE, c, rs);

    int filter_x, filter_y;
    _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

    int output_pixel_x_save = output_pixel_x;
    for(; output_pixel_y < P; output_pixel_y += grid_p)
    {
        for(output_pixel_x = output_pixel_x_save; output_pixel_x < Q; output_pixel_x += grid_q)
        {
            int base_x = output_pixel_x * stride_w - padding_w + filter_x * dilation_w;
            int base_y = output_pixel_y * stride_h - padding_h + filter_y * dilation_h;
            int crs_in_bounds = (c < C) && (base_x >= 0) && (base_x < W) &&
                                (base_y >= 0) && (base_y < H);
            int input_pixel = W * base_y + base_x;
            output_pixel = output_pixel_x + (Q * output_pixel_y);

            //Pre-multiply offset to simplify indexing
            input_pixel = (input_pixel * N) >> 2;
            output_pixel = (output_pixel * N) >> 2;

            //Evaluate gemm with outer product dimensions N, K and inner product CRS
            Matrix result[ITEMS_PER_THREAD] = {0};

            //Load image data and transpose into shared mem
            //TODO: pad shared memory to avoid bank conflicts
            Matrix buffer;
            buffer.f4 = crs_in_bounds ?
                        I[(c * input_channel_size) + input_pixel + threadIdx.x].f4 :
                        make_float4(0, 0, 0, 0);
            %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0];
            %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1];
            %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2];
            %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3];

            //Load error data and transpose into shared mem
            buffer.f4 = (load_filter_id < K) ?
                        F[(load_filter_id * output_filter_size) + output_pixel + threadIdx.x].f4 :
                        make_float4(0, 0, 0, 0);
            %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0];
            %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1];
            %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2];
            %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3];

            //Iterate over entire minibatch
            for(int n = threadIdx.x + (TILE_DIM >> 2); n < (N >> 2); n += (TILE_DIM >> 2))
            {
                __syncthreads();

                #pragma unroll
                for(int i = 0; i < (TILE_DIM >> 2); i++)
                {
                    Matrix row_image;
                    Matrix row_error;

                    row_image.f4 =
                        %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4;
                    row_error.f4 =
                        %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4;

                    //Accumulate product
                    #pragma unroll
                    for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++)
                    {
                        #pragma unroll
                        for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++)
                        {
                            result[p_offset].f[q_offset] +=
                                (row_image.f[p_offset] * row_error.f[q_offset]);
                        }
                    }
                }

                __syncthreads();

                //Load image data and transpose into shared mem
                buffer.f4 = crs_in_bounds ?
                    I[(c * input_channel_size) + input_pixel + n].f4 :
                    make_float4(0, 0, 0, 0);
                %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[0];
                %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[1];
                %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[2];
                %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[3];

                //Load error data and transpose into shared mem
                buffer.f4 = (load_filter_id < K) ?
                    F[(load_filter_id * output_filter_size) + output_pixel + n].f4 :
                    make_float4(0, 0, 0, 0);
                %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[0];
                %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[1];
                %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[2];
                %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[3];
            }

            __syncthreads();

            //Accumulate product for last iteration
            #pragma unroll
            for(int i = 0; i < (TILE_DIM >> 2); i++)
            {
                Matrix row_image;
                Matrix row_error;

                row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4;
                row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++)
                    {
                        result[p_offset].f[q_offset] +=
                            (row_image.f[p_offset] * row_error.f[q_offset]);
                    }
                }
            }

            //Reduce result between threads in warp
            Matrix outbound;
            int warp_y = threadIdx.y & 3;
            int warp_id = threadIdx.x + (threadIdx.y << 3);
            buffer.f4 = (warp_y == 0) ? result[0].f4 :
                        (warp_y == 1) ? result[1].f4 :
                        (warp_y == 2) ? result[2].f4 :
                        result[3].f4;

            outbound.f4 = (warp_y == 0) ? result[3].f4 :
                          (warp_y == 1) ? result[0].f4 :
                          (warp_y == 2) ? result[1].f4 :
                          result[2].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 8);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 8);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 8);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 8);

            outbound.f4 = (warp_y == 0) ? result[2].f4 :
                          (warp_y == 1) ? result[3].f4 :
                          (warp_y == 2) ? result[0].f4 :
                          result[1].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 16);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 16);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 16);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 16);

            outbound.f4 = (warp_y == 0) ? result[1].f4 :
                          (warp_y == 1) ? result[2].f4 :
                          (warp_y == 2) ? result[3].f4 :
                          result[0].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 24);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 24);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 24);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 24);

            //Store result
            int idx_filter_id = filter_id + (threadIdx.x << 2);
            if(idx_filter_id < K && crs_in_bounds)
            {
                int out_index = (c * filter_channel_size) + (((rs * K) + (idx_filter_id)) >> 2);

                atomicAdd(&O[out_index].f[0], buffer.f[0]);
                atomicAdd(&O[out_index].f[1], buffer.f[1]);
                atomicAdd(&O[out_index].f[2], buffer.f[2]);
                atomicAdd(&O[out_index].f[3], buffer.f[3]);
            }
        }
    }
}
"""
    if operation == "update":
        code = header_code + update_code
    else:
        code = header_code + code

    magic = _magic64(filter_size)

    code = code % {
        "filter_size": filter_size,
        "magic_filter_size": magic[0],
        "shift_filter_size": magic[1],
        "type": _ew_types[dtype]["type"],
        "lut_code": lut_code,
        "bsum_code": bsum_code,
        "operation": operation,
        "a_name": a_name,
        "b_name": b_name,
        "filter_load_cond": filter_load_cond,
        "check_filter_cond": check_filter_cond
    }

    options = ["--use_fast_math"]
    if debug and operation == "bprop":
        options = options + ["-g", "-G"]
    module = SourceModule(code, options=options)

    kernel = module.get_function("conv_" + operation)
    kernel.prepare("ffPPPPIIIIIIIIIIIIIIIIIIIIIIIIIIIIII")
    kernel.name = "conv_" + operation
    return kernel
Exemplo n.º 11
0
def _get_bn_bprop_kernel(dtype, threads, compute_capability):

    if threads > 32:
        shr_code = "__shared__ float sPartials[THREADS * 2];"
        red_code = r"""
    sPartials[tid + THREADS*0] = grad_gamma;
    sPartials[tid + THREADS*1] = grad_beta;
    __syncthreads();

    #pragma unroll
    for (int a = THREADS >> 1; a > 32; a >>= 1)
    {
        if ( tid < a )
        {
            sPartials[tid + THREADS*0] += sPartials[tid + a + THREADS*0];
            sPartials[tid + THREADS*1] += sPartials[tid + a + THREADS*1];
        }
        __syncthreads();
    }
    if ( tid < 32 )
    {
        grad_gamma = sPartials[tid + THREADS*0] + sPartials[tid + 32 + THREADS*0];
        grad_beta  = sPartials[tid + THREADS*1] + sPartials[tid + 32 + THREADS*1];

        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
        {
            grad_gamma += __shfl_xor(grad_gamma, i);
            grad_beta  += __shfl_xor(grad_beta,  i);
        }
        sPartials[tid + THREADS*0] = grad_gamma;
        sPartials[tid + THREADS*1] = grad_beta;
    }
    __syncthreads();
    grad_gamma = sPartials[THREADS*0];
    grad_beta  = sPartials[THREADS*1];
"""
    else:
        shr_code = ""
        red_code = r"""
    #pragma unroll
    for (int i = 16; i > 0; i >>= 1)
    {
        grad_gamma += __shfl_xor(grad_gamma, i);
        grad_beta  += __shfl_xor(grad_beta,  i);
    }
"""

    code = r"""
#define THREADS %(threads)s

%(common)s
%(binary)s

__global__ void batchnorm_bprop (
    %(type)s* delta_out, float* grad_gamma_out, float* grad_beta_out,
    const %(type)s* delta_in, const %(type)s* x_in, const float* xsum_in,
    const float* xvar_in, const float* gamma_in,
    const float eps, const int N, bool binary)
{
    %(share)s

    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;
    const float rcpN = 1.0f/(float)N;
    int offset = bid * N;

    const %(type)s* x_in0 = x_in     + offset + tid;
    const %(type)s* d_in0 = delta_in + offset + tid;

    float xmean = __ldg(xsum_in  + bid) * rcpN;
    float xvar  = __ldg(xvar_in  + bid);
    float gamma = __ldg(gamma_in + bid);

    float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps);
    float grad_gamma    = 0.0f;
    float grad_beta     = 0.0f;

    for (int i = tid; i < N; i += THREADS)
    {
        float x = %(cvt)s(__ldg(x_in0));
        x_in0 += THREADS;
        float d = %(cvt)s(__ldg(d_in0));
        d_in0 += THREADS;

        float xhat = 0.0f;
        if (binary) {
            xhat = shift_element(x - xmean, xvar_rcp_sqrt, true);
        } else {
            xhat = (x - xmean) * xvar_rcp_sqrt;
        }

        grad_gamma += xhat * d;
        grad_beta  += d;
    }
    %(red)s

    if ( tid == 0 )
    {
        *(grad_gamma_out + bid) = grad_gamma;
        *(grad_beta_out  + bid) = grad_beta;
    }

    int start = N - (THREADS*4 - tid);
    offset += start;
    const %(type)s* x_in1 = x_in     + offset;
    const %(type)s* d_in1 = delta_in + offset;
    delta_out += offset;

    for (int i = start; i >= -THREADS*3; i -= THREADS*4)
    {
        float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in1 + THREADS*0)) : 0.0f;
        float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in1 + THREADS*1)) : 0.0f;
        float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in1 + THREADS*2)) : 0.0f;
        float x3 =                   %(cvt)s(__ldg(x_in1 + THREADS*3));

        float d0 = i >= -THREADS*0 ? %(cvt)s(__ldg(d_in1 + THREADS*0)) : 0.0f;
        float d1 = i >= -THREADS*1 ? %(cvt)s(__ldg(d_in1 + THREADS*1)) : 0.0f;
        float d2 = i >= -THREADS*2 ? %(cvt)s(__ldg(d_in1 + THREADS*2)) : 0.0f;
        float d3 =                   %(cvt)s(__ldg(d_in1 + THREADS*3));

        x_in1 -= THREADS*4;
        d_in1 -= THREADS*4;

        float xhat0 = 0.0f;
        float xhat1 = 0.0f;
        float xhat2 = 0.0f;
        float xhat3 = 0.0f;

        float xtmp0 = 0.0f;
        float xtmp1 = 0.0f;
        float xtmp2 = 0.0f;
        float xtmp3 = 0.0f;

        float delta0 = 0.0f;
        float delta1 = 0.0f;
        float delta2 = 0.0f;
        float delta3 = 0.0f;

        if (binary) {
            xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true);
            xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true);
            xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true);
            xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true);

            xtmp0 = (shift_element(xhat0, grad_gamma, true) + grad_beta) * rcpN;
            xtmp1 = (shift_element(xhat1, grad_gamma, true) + grad_beta) * rcpN;
            xtmp2 = (shift_element(xhat2, grad_gamma, true) + grad_beta) * rcpN;
            xtmp3 = (shift_element(xhat3, grad_gamma, true) + grad_beta) * rcpN;

            delta0 = shift_element(shift_element(d0 - xtmp0, gamma, true), xvar_rcp_sqrt, true);
            delta1 = shift_element(shift_element(d1 - xtmp1, gamma, true), xvar_rcp_sqrt, true);
            delta2 = shift_element(shift_element(d2 - xtmp2, gamma, true), xvar_rcp_sqrt, true);
            delta3 = shift_element(shift_element(d3 - xtmp3, gamma, true), xvar_rcp_sqrt, true);
        } else {
            xhat0 = (x0 - xmean) * xvar_rcp_sqrt;
            xhat1 = (x1 - xmean) * xvar_rcp_sqrt;
            xhat2 = (x2 - xmean) * xvar_rcp_sqrt;
            xhat3 = (x3 - xmean) * xvar_rcp_sqrt;

            xtmp0 = (xhat0 * grad_gamma + grad_beta) * rcpN;
            xtmp1 = (xhat1 * grad_gamma + grad_beta) * rcpN;
            xtmp2 = (xhat2 * grad_gamma + grad_beta) * rcpN;
            xtmp3 = (xhat3 * grad_gamma + grad_beta) * rcpN;

            delta0 = gamma * (d0 - xtmp0) * xvar_rcp_sqrt;
            delta1 = gamma * (d1 - xtmp1) * xvar_rcp_sqrt;
            delta2 = gamma * (d2 - xtmp2) * xvar_rcp_sqrt;
            delta3 = gamma * (d3 - xtmp3) * xvar_rcp_sqrt;
        }

        %(delta0_out)s
        %(delta1_out)s
        %(delta2_out)s
        %(delta3_out)s
        if (i >= -THREADS*0) *(delta_out + THREADS*0) = delta0_val;
        if (i >= -THREADS*1) *(delta_out + THREADS*1) = delta1_val;
        if (i >= -THREADS*2) *(delta_out + THREADS*2) = delta2_val;
                             *(delta_out + THREADS*3) = delta3_val;
        delta_out -= THREADS*4;
    }
}
"""
    out_code = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};")
    common_code = _common_round["nearest"].get(dtype, "")
    if dtype == "f2":
        common_code += _common_fp16_to_fp32

    if (compute_capability[0] == 3 and compute_capability[1] < 5) or compute_capability[0] < 3:
        common_code += _common_kepler

    code = code % {
        "common"         : common_code,
        "binary"         : shift_element(),
        "share"          : shr_code,
        "red"            : red_code,
        "threads"        : threads,
        "type"           : _ew_types[dtype]["type"],
        "cvt"            : _ew_types[dtype]["cvt"],
        "delta0_out"     : out_code.format("delta0_val",     "delta0"),
        "delta1_out"     : out_code.format("delta1_val",     "delta1"),
        "delta2_out"     : out_code.format("delta2_val",     "delta2"),
        "delta3_out"     : out_code.format("delta3_val",     "delta3"),
    }
    module = SourceModule(code, options=["--use_fast_math"])
    kernel = module.get_function("batchnorm_bprop")
    kernel.prepare("PPPPPPPPfII")
    kernel.name = "batchnorm_bprop"
    return kernel
Exemplo n.º 12
0
def _get_bn_fprop_kernel(dtype, threads, compute_capability):

    if threads > 32:
        shr_code = "__shared__ float sPartials[THREADS];"
        red_code = r"""
    sPartials[tid] = xvar;
    __syncthreads();

    #pragma unroll
    for (int a = THREADS >> 1; a > 32; a >>= 1)
    {
        if ( tid < a )
            sPartials[tid] += sPartials[tid + a];
        __syncthreads();
    }
    if ( tid < 32 )
    {
        xvar = sPartials[tid] + sPartials[tid + 32];
        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
            xvar += __shfl_xor(xvar, i);

        sPartials[tid] = xvar * rcpN;
    }
    __syncthreads();
    xvar = sPartials[0];
"""
    else:
        shr_code = ""
        red_code = r"""
    #pragma unroll
    for (int i = 16; i > 0; i >>= 1)
        xvar += __shfl_xor(xvar, i);
    xvar *= rcpN;
"""

    code = r"""
#define THREADS %(threads)s

%(common)s
%(binary)s

__global__ void batchnorm_fprop (
    %(type)s* y_out, float* xvar_out, float* gmean_out, float* gvar_out,
    const %(type)s* x_in, const float* xsum_in, const float* gmean_in,
    const float* gvar_in, const float* gamma_in, const float* beta_in,
    const float eps, const float rho, const float accumbeta, const int N,
    const int relu, bool binary)
{
    %(share)s

    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;
    int offset = bid * N;

    const %(type)s* x_in0 = x_in + offset + tid;

    const float rcpN = 1.0f/(float)N;

    float xmean = __ldg(xsum_in + bid) * rcpN;

    float xvar = 0.0f;
    for (int i = tid; i < N; i += THREADS)
    {
        float x = %(cvt)s(__ldg(x_in0));
        x_in0 += THREADS;

        x -= xmean;
        if (binary) {
            xvar += shift_element(x, x, true);
        } else {
            xvar += x * x;
        }
    }
    %(red)s

    float gamma = __ldg(gamma_in + bid);
    float beta  = __ldg(beta_in  + bid);

    if ( tid == 0 )
    {
        float gmean = __ldg(gmean_in + bid);
        float gvar  = __ldg(gvar_in  + bid);

        *(xvar_out  + bid) = xvar;
        *(gmean_out + bid) = gmean * rho + (1.0f - rho) * xmean;
        *(gvar_out  + bid) = gvar  * rho + (1.0f - rho) * xvar;
    }

    float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps);

    int start = N - (THREADS*4 - tid);
    offset += start;
    x_in   += offset;
    y_out  += offset;

    for (int i = start; i >= -THREADS*3; i -= THREADS*4)
    {
        float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in + THREADS*0)) : 0.0f;
        float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in + THREADS*1)) : 0.0f;
        float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in + THREADS*2)) : 0.0f;
        float x3 =                   %(cvt)s(__ldg(x_in + THREADS*3));

        x_in -= THREADS*4;

        float xhat0 = 0.0f;
        float xhat1 = 0.0f;
        float xhat2 = 0.0f;
        float xhat3 = 0.0f;

        float y0 = 0.0f;
        float y1 = 0.0f;
        float y2 = 0.0f;
        float y3 = 0.0f;
        if (binary) {
            xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true);
            xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true);
            xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true);
            xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true);

            y0 = shift_element(xhat0, gamma, true) + beta;
            y1 = shift_element(xhat1, gamma, true) + beta;
            y2 = shift_element(xhat2, gamma, true) + beta;
            y3 = shift_element(xhat3, gamma, true) + beta;
        } else {
            xhat0 = (x0 - xmean) * xvar_rcp_sqrt;
            xhat1 = (x1 - xmean) * xvar_rcp_sqrt;
            xhat2 = (x2 - xmean) * xvar_rcp_sqrt;
            xhat3 = (x3 - xmean) * xvar_rcp_sqrt;

            y0 = xhat0 * gamma + beta;
            y1 = xhat1 * gamma + beta;
            y2 = xhat2 * gamma + beta;
            y3 = xhat3 * gamma + beta;
        }

        if (relu)
        {
            y0 = fmaxf(y0, 0.0f);
            y1 = fmaxf(y1, 0.0f);
            y2 = fmaxf(y2, 0.0f);
            y3 = fmaxf(y3, 0.0f);
        }

        %(y0_out)s
        %(y1_out)s
        %(y2_out)s
        %(y3_out)s
        if (accumbeta == 0.0)
        {
            if (i >= -THREADS*0) *(y_out + THREADS*0) = y0_val;
            if (i >= -THREADS*1) *(y_out + THREADS*1) = y1_val;
            if (i >= -THREADS*2) *(y_out + THREADS*2) = y2_val;
                                 *(y_out + THREADS*3) = y3_val;
        }
        else
        {
            if (i >= -THREADS*0) *(y_out + THREADS*0) = y_out[THREADS*0] * accumbeta + y0_val;
            if (i >= -THREADS*1) *(y_out + THREADS*1) = y_out[THREADS*1] * accumbeta + y1_val;
            if (i >= -THREADS*2) *(y_out + THREADS*2) = y_out[THREADS*2] * accumbeta + y2_val;
                                 *(y_out + THREADS*3) = y_out[THREADS*3] * accumbeta + y3_val;
        }
        y_out -= THREADS*4;
    }
}
"""
    out_code = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};")
    common_code = _common_round["nearest"].get(dtype, "")
    if dtype == "f2":
        common_code += _common_fp16_to_fp32

    if (compute_capability[0] == 3
            and compute_capability[1] < 5) or compute_capability[0] < 3:
        common_code += _common_kepler

    code = code % {
        "common": common_code,
        "binary": shift_element(),
        "share": shr_code,
        "red": red_code,
        "threads": threads,
        "type": _ew_types[dtype]["type"],
        "cvt": _ew_types[dtype]["cvt"],
        "y0_out": out_code.format("y0_val", "y0"),
        "y1_out": out_code.format("y1_val", "y1"),
        "y2_out": out_code.format("y2_val", "y2"),
        "y3_out": out_code.format("y3_val", "y3"),
    }
    module = SourceModule(code, options=["--use_fast_math"])
    kernel = module.get_function("batchnorm_fprop")
    kernel.prepare("PPPPPPPPPPfffIII")
    kernel.name = "batchnorm_fprop"
    return kernel
Exemplo n.º 13
0
def _get_bn_bprop_kernel(dtype, threads, compute_capability):

    if threads > 32:
        shr_code = "__shared__ float sPartials[THREADS * 2];"
        red_code = r"""
    sPartials[tid + THREADS*0] = grad_gamma;
    sPartials[tid + THREADS*1] = grad_beta;
    __syncthreads();

    #pragma unroll
    for (int a = THREADS >> 1; a > 32; a >>= 1)
    {
        if ( tid < a )
        {
            sPartials[tid + THREADS*0] += sPartials[tid + a + THREADS*0];
            sPartials[tid + THREADS*1] += sPartials[tid + a + THREADS*1];
        }
        __syncthreads();
    }
    if ( tid < 32 )
    {
        grad_gamma = sPartials[tid + THREADS*0] + sPartials[tid + 32 + THREADS*0];
        grad_beta  = sPartials[tid + THREADS*1] + sPartials[tid + 32 + THREADS*1];

        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
        {
            grad_gamma += __shfl_xor(grad_gamma, i);
            grad_beta  += __shfl_xor(grad_beta,  i);
        }
        sPartials[tid + THREADS*0] = grad_gamma;
        sPartials[tid + THREADS*1] = grad_beta;
    }
    __syncthreads();
    grad_gamma = sPartials[THREADS*0];
    grad_beta  = sPartials[THREADS*1];
"""
    else:
        shr_code = ""
        red_code = r"""
    #pragma unroll
    for (int i = 16; i > 0; i >>= 1)
    {
        grad_gamma += __shfl_xor(grad_gamma, i);
        grad_beta  += __shfl_xor(grad_beta,  i);
    }
"""

    code = r"""
#define THREADS %(threads)s

%(common)s
%(binary)s

__global__ void batchnorm_bprop (
    %(type)s* delta_out, float* grad_gamma_out, float* grad_beta_out,
    const %(type)s* delta_in, const %(type)s* x_in, const float* xsum_in,
    const float* xvar_in, const float* gamma_in,
    const float eps, const int N, bool binary)
{
    %(share)s

    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;
    const float rcpN = 1.0f/(float)N;
    int offset = bid * N;

    const %(type)s* x_in0 = x_in     + offset + tid;
    const %(type)s* d_in0 = delta_in + offset + tid;

    float xmean = __ldg(xsum_in  + bid) * rcpN;
    float xvar  = __ldg(xvar_in  + bid);
    float gamma = __ldg(gamma_in + bid);

    float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps);
    float grad_gamma    = 0.0f;
    float grad_beta     = 0.0f;

    for (int i = tid; i < N; i += THREADS)
    {
        float x = %(cvt)s(__ldg(x_in0));
        x_in0 += THREADS;
        float d = %(cvt)s(__ldg(d_in0));
        d_in0 += THREADS;

        float xhat = 0.0f;
        if (binary) {
            xhat = shift_element(x - xmean, xvar_rcp_sqrt, true);
        } else {
            xhat = (x - xmean) * xvar_rcp_sqrt;
        }

        grad_gamma += xhat * d;
        grad_beta  += d;
    }
    %(red)s

    if ( tid == 0 )
    {
        *(grad_gamma_out + bid) = grad_gamma;
        *(grad_beta_out  + bid) = grad_beta;
    }

    int start = N - (THREADS*4 - tid);
    offset += start;
    const %(type)s* x_in1 = x_in     + offset;
    const %(type)s* d_in1 = delta_in + offset;
    delta_out += offset;

    for (int i = start; i >= -THREADS*3; i -= THREADS*4)
    {
        float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in1 + THREADS*0)) : 0.0f;
        float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in1 + THREADS*1)) : 0.0f;
        float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in1 + THREADS*2)) : 0.0f;
        float x3 =                   %(cvt)s(__ldg(x_in1 + THREADS*3));

        float d0 = i >= -THREADS*0 ? %(cvt)s(__ldg(d_in1 + THREADS*0)) : 0.0f;
        float d1 = i >= -THREADS*1 ? %(cvt)s(__ldg(d_in1 + THREADS*1)) : 0.0f;
        float d2 = i >= -THREADS*2 ? %(cvt)s(__ldg(d_in1 + THREADS*2)) : 0.0f;
        float d3 =                   %(cvt)s(__ldg(d_in1 + THREADS*3));

        x_in1 -= THREADS*4;
        d_in1 -= THREADS*4;

        float xhat0 = 0.0f;
        float xhat1 = 0.0f;
        float xhat2 = 0.0f;
        float xhat3 = 0.0f;

        float xtmp0 = 0.0f;
        float xtmp1 = 0.0f;
        float xtmp2 = 0.0f;
        float xtmp3 = 0.0f;

        float delta0 = 0.0f;
        float delta1 = 0.0f;
        float delta2 = 0.0f;
        float delta3 = 0.0f;

        if (binary) {
            xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true);
            xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true);
            xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true);
            xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true);

            xtmp0 = (shift_element(xhat0, grad_gamma, true) + grad_beta) * rcpN;
            xtmp1 = (shift_element(xhat1, grad_gamma, true) + grad_beta) * rcpN;
            xtmp2 = (shift_element(xhat2, grad_gamma, true) + grad_beta) * rcpN;
            xtmp3 = (shift_element(xhat3, grad_gamma, true) + grad_beta) * rcpN;

            delta0 = shift_element(shift_element(d0 - xtmp0, gamma, true), xvar_rcp_sqrt, true);
            delta1 = shift_element(shift_element(d1 - xtmp1, gamma, true), xvar_rcp_sqrt, true);
            delta2 = shift_element(shift_element(d2 - xtmp2, gamma, true), xvar_rcp_sqrt, true);
            delta3 = shift_element(shift_element(d3 - xtmp3, gamma, true), xvar_rcp_sqrt, true);
        } else {
            xhat0 = (x0 - xmean) * xvar_rcp_sqrt;
            xhat1 = (x1 - xmean) * xvar_rcp_sqrt;
            xhat2 = (x2 - xmean) * xvar_rcp_sqrt;
            xhat3 = (x3 - xmean) * xvar_rcp_sqrt;

            xtmp0 = (xhat0 * grad_gamma + grad_beta) * rcpN;
            xtmp1 = (xhat1 * grad_gamma + grad_beta) * rcpN;
            xtmp2 = (xhat2 * grad_gamma + grad_beta) * rcpN;
            xtmp3 = (xhat3 * grad_gamma + grad_beta) * rcpN;

            delta0 = gamma * (d0 - xtmp0) * xvar_rcp_sqrt;
            delta1 = gamma * (d1 - xtmp1) * xvar_rcp_sqrt;
            delta2 = gamma * (d2 - xtmp2) * xvar_rcp_sqrt;
            delta3 = gamma * (d3 - xtmp3) * xvar_rcp_sqrt;
        }

        %(delta0_out)s
        %(delta1_out)s
        %(delta2_out)s
        %(delta3_out)s
        if (i >= -THREADS*0) *(delta_out + THREADS*0) = delta0_val;
        if (i >= -THREADS*1) *(delta_out + THREADS*1) = delta1_val;
        if (i >= -THREADS*2) *(delta_out + THREADS*2) = delta2_val;
                             *(delta_out + THREADS*3) = delta3_val;
        delta_out -= THREADS*4;
    }
}
"""
    out_code = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};")
    common_code = _common_round["nearest"].get(dtype, "")
    if dtype == "f2":
        common_code += _common_fp16_to_fp32

    if (compute_capability[0] == 3
            and compute_capability[1] < 5) or compute_capability[0] < 3:
        common_code += _common_kepler

    code = code % {
        "common": common_code,
        "binary": shift_element(),
        "share": shr_code,
        "red": red_code,
        "threads": threads,
        "type": _ew_types[dtype]["type"],
        "cvt": _ew_types[dtype]["cvt"],
        "delta0_out": out_code.format("delta0_val", "delta0"),
        "delta1_out": out_code.format("delta1_val", "delta1"),
        "delta2_out": out_code.format("delta2_val", "delta2"),
        "delta3_out": out_code.format("delta3_val", "delta3"),
    }
    module = SourceModule(code, options=["--use_fast_math"])
    kernel = module.get_function("batchnorm_bprop")
    kernel.prepare("PPPPPPPPfII")
    kernel.name = "batchnorm_bprop"
    return kernel
Exemplo n.º 14
0
def _get_conv_kernel(dtype,
                     filter_size,
                     bsum,
                     operation,
                     filter_bounds_check=False,
                     debug=False):
    """
    Builds the convolution kernel for a specified filter size.

    Arguments:
        dtype (np.dtype): The data type which the kernel will operate on.
        filter_size (int): Total number of elements per filter (R * S).
        bsum (boolean): If set to True, the kernel will include code to compute
            the batch sum during fprop.
        operation (string): Determines which kernel to build. Options are:
            'fprop': Forward propagation of activations.
            'bprop': Backward propagation of error.
            'update': Computes gradients for filter weights based on error and inputs.
        filter_bounds_check (boolean): Checks if filter weight is in bounds when K is
            not a multiple of 32.
        debug (boolean): When set to true, kernels will be compiled with debug symbols.
    """
    assert operation in ["fprop", "bprop", "update"]
    if operation == "fprop" or operation == "update":
        lut_code = r"""
    if(tid < 32)
    {
        int rs = tid;
        int base_x, base_y;

        base_x = output_pixel_x * stride_w - padding_w;
        base_y = output_pixel_y * stride_h - padding_h;

        unsigned int mask = (1 << tid) - 1;

        while(rs < FILTER_SIZE)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

            int index_x = base_x + filter_x;
            int index_y = base_y + filter_y;

            //Check if the index is valid
            int in_bounds = (index_x >= 0 && index_x < W && index_y >= 0 && index_y < H);
            unsigned int threads_in_bounds = __ballot(in_bounds);

            //Store lookup table entry
            if(in_bounds)
            {
                int2 lut_entry;
                lut_entry.x = ((index_y * W + index_x) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;

                int index = lut_size_local + __popc(threads_in_bounds & mask);
                lookup_table[index] = lut_entry;
            }

            lut_size_local += __popc(threads_in_bounds);

            rs += 32;
        }
    }
"""
    elif operation == "bprop":
        lut_code = r"""
    if(tid < 32)
    {
        int rs = tid;
        int base_q, base_p;

        base_q = output_pixel_x - (S - padding_w - 1);
        base_p = output_pixel_y - (R - padding_h - 1);

        unsigned int mask = (1 << tid) - 1;

        while(rs < FILTER_SIZE)
        {
            int filter_x, filter_y;
            _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

            int index_q = base_q + filter_x;
            int index_p = base_p + filter_y;

            //Check if the index is valid
            int in_bounds = (((index_q % stride_w) | (index_p % stride_h)) == 0);
            index_q /= stride_w;
            index_p /= stride_h;
            in_bounds = in_bounds && (index_q >= 0 && index_q < W
                                      && index_p >= 0 && index_p < H);
            unsigned int threads_in_bounds = __ballot(in_bounds);

            //Store lookup table entry
            if(in_bounds)
            {
                int2 lut_entry;
                lut_entry.x = (((index_p * W) + index_q) * N) >> 2;
                lut_entry.y = (rs * K) >> 2;

                int index = lut_size_local + __popc(threads_in_bounds & mask);
                lookup_table[index] = lut_entry;
            }

            lut_size_local += __popc(threads_in_bounds);

            rs += 32;
        }
    }
"""
    if bsum:
        bsum_code = r"""
            float local_bsum = result[q_offset].f[0] + result[q_offset].f[1] +
                               result[q_offset].f[2] + result[q_offset].f[3];
            atomicAdd(&bsum[filter_id], local_bsum);
"""
    else:
        bsum_code = ""

    if operation == "update":
        a_name = "image"
        b_name = "error"
    else:
        if operation == "fprop":
            a_name = "image"
            b_name = "filter"
        elif operation == "bprop":
            a_name = "error"
            b_name = "filter"

    if filter_bounds_check:
        filter_load_cond = "int filter_load_in_bounds = (((filter_id + threadIdx.x) << 2) < K);"
        check_filter_cond = "(!filter_load_in_bounds) ? make_float4(0, 0, 0, 0) :"
    else:
        filter_load_cond = ""
        check_filter_cond = ""

    header_code = r"""
#define TILE_DIM            32
#define ITEMS_PER_THREAD    4
#define THREADS_DIM         8

#define REG_TILE_X          4
#define REG_TILE_Y          4
#define THREADS_DIM_X       8
#define THREADS_DIM_Y       8
#define SM_TILE_X           (REG_TILE_X * THREADS_DIM_X)
#define SM_TILE_Y           (REG_TILE_Y * THREADS_DIM_Y)

#define NUM_ROWS            8
#define FILTER_SIZE         %(filter_size)s
#define MAGIC_FILTER_SIZE   %(magic_filter_size)s
#define SHIFT_FILTER_SIZE   %(shift_filter_size)s

typedef union Matrix {
    %(type)s4 f4;
    %(type)s f[4];
} Matrix;

__device__ inline void _idiv_fast(int numerator, int denominator, float rcp,
                                 int& result, int& remainder)
{
    result = (int)((float)numerator * rcp);
    remainder = numerator - (result * denominator);
    result = (remainder >= denominator) ? (result + 1) : result;
    remainder = (remainder >= denominator) ? (remainder - denominator) : remainder;
}

__device__ inline void _idiv_magic(int numerator, unsigned int magic, unsigned int shift,
                                   int denominator, int& result, int& remainder)
{
    if(magic == 1)
    {
        result = numerator >> shift;
    }
    else
    {
        unsigned long long res64 = numerator * (unsigned long long)magic;
        result = ((int)(res64 >> 32) >> shift);
    }
    remainder = numerator - (result * denominator);
}

__device__ inline void _idiv_magic32(int numerator, unsigned int magic, unsigned int shift,
                                     int denominator, int& result, int& remainder)
{
    if(magic == 1)
    {
        result = numerator >> shift;
    }
    else
    {
        result = ((numerator * magic) >> shift);
    }
    remainder = numerator - (result * denominator);
}

//Note: N and K must be multiples of 4
//blockIdx.x is gemm tile id (K dimension) and output pixel id
//blockIdx.y is gemm tile id (N dimension)
//threadIdx.x is gemm tile offset (K dimension)
//threadIdx.y is gemm tile offset (N dimension)
__global__ void conv_%(operation)s(
                           %(type)s alpha, %(type)s beta,
                           Matrix *I, Matrix *F, Matrix *O, float* bsum,
                           int C, int D, int H, int W, int N,
                           int T, int R, int S, int K,
                           int M, int P, int Q,
                           int stride_w, int stride_h, int padding_w, int padding_h,
                           int input_channel_size, int filter_channel_size,
                           int output_filter_size,
                           int output_pixels, int grid_p, int grid_q,
                           unsigned int magic_pq, unsigned int shift_pq,
                           unsigned int magic_q, unsigned int shift_q,
                           unsigned int magic_s, unsigned int shift_s)

"""
    code = r"""
{
    __shared__ int2 lookup_table[FILTER_SIZE];
    __shared__ int lut_size;
    __shared__ Matrix %(a_name)s_data[NUM_ROWS][THREADS_DIM_X];
    __shared__ Matrix %(b_name)s_data[NUM_ROWS][THREADS_DIM_Y];

    int lut_size_local = 0;

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, image_id;
    _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, image_id, output_pixel);
    image_id = (image_id * blockDim.x);

    //Zig zag along x axis to increase cache hits
    int temp_x, temp_y;
    _idiv_magic(output_pixel, magic_q, shift_q, Q, temp_y, temp_x);
    int output_pixel_x = (temp_y & 1) ? (Q - temp_x - 1) : temp_x;
    int output_pixel_y = temp_y;
    output_pixel = output_pixel_x + (output_pixel_y * Q);

    int filter_id = blockIdx.y * blockDim.y;
    int tid = threadIdx.x + threadIdx.y * blockDim.x;

    //Offset buffers based on thread id
    I = &(I[image_id  + threadIdx.x]);
    F = &(F[filter_id + threadIdx.x]);

    %(filter_load_cond)s

    //Compute lookup table for filter/image data
%(lut_code)s

    if(tid == 0)
    {
        lut_size = lut_size_local;
    }

    __syncthreads();

    lut_size_local = lut_size;
    Matrix result[REG_TILE_Y] = {0};
    output_pixel = (output_pixel * N) >> 2;
    if(lut_size_local > 0)
    {
        //Evaluate gemm with outer product dimensions N, K and inner product CRS
        int CRS = lut_size_local * C;

        //Compute magic numbers for division by lut_size
        float reciprocal = 1.0f / (float)lut_size_local;

        //Initialize shared mem for first block
        int crs = CRS %% NUM_ROWS;
        crs = (crs == 0) ? 8 : crs;

        int c, rs;
        _idiv_fast(CRS - threadIdx.y - 1, lut_size_local, reciprocal, c, rs);

        int2 lut_entry = ((threadIdx.y & 7) >= crs) ? make_int2(0, 0) : lookup_table[rs];
        %(a_name)s_data[threadIdx.y][threadIdx.x].f4 =
            ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) :
            I[(c * input_channel_size)  + lut_entry.x].f4;
        %(b_name)s_data[threadIdx.y][threadIdx.x].f4 = %(check_filter_cond)s
            ((threadIdx.y & 7) >= crs) ? make_float4(0, 0, 0, 0) :
            F[(c * filter_channel_size) + lut_entry.y].f4;

        //Iterate over entire filter
        for(crs = CRS - crs - 1; crs > 0; crs -= NUM_ROWS)
        {
            __syncthreads();

            #pragma unroll
            for(int i = 0; i < NUM_ROWS; i++)
            {
                Matrix load_row;
                Matrix load_col;

                load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4;
                load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                    {
                        result[q_offset].f[p_offset] += (load_row.f[p_offset] *
                                                         load_col.f[q_offset]);
                    }
                }
            }

            __syncthreads();

            //Load new image data and filter weights
            _idiv_fast(crs - threadIdx.y, lut_size_local, reciprocal, c, rs);

            lut_entry = lookup_table[rs];
            %(a_name)s_data[threadIdx.y][threadIdx.x].f4 =
                I[(c * input_channel_size)  + lut_entry.x].f4;
            %(b_name)s_data[threadIdx.y][threadIdx.x].f4 =
                %(check_filter_cond)s F[(c * filter_channel_size) + lut_entry.y].f4;
        }

        __syncthreads();

        //Accumulate product for last iteration
        #pragma unroll
        for(int i = 0; i < NUM_ROWS; i++)
        {
            Matrix load_row;
            Matrix load_col;

            load_row.f4 = %(a_name)s_data[i][threadIdx.x].f4;
            load_col.f4 = %(b_name)s_data[i][threadIdx.y].f4;

            //Accumulate product
            #pragma unroll
            for(int q_offset = 0; q_offset < REG_TILE_Y; q_offset++)
            {
                #pragma unroll
                for(int p_offset = 0; p_offset < REG_TILE_X; p_offset++)
                {
                    result[q_offset].f[p_offset] += (load_row.f[p_offset] * load_col.f[q_offset]);
                }
            }
        }
    }

    //Store result
    filter_id = (filter_id + threadIdx.y) << 2;
    if(filter_id < K)
    {
        image_id += threadIdx.x;

        #pragma unroll
        for(int q_offset = 0; q_offset < 4; q_offset++)
        {
            if(filter_id < K)
            {
                int out_index = (filter_id * output_filter_size) + output_pixel + image_id;
                %(bsum_code)s

                Matrix cur_value = {0};
                if(beta > 0.0f)
                {
                    cur_value.f4 = O[out_index].f4;
                }

                result[q_offset].f[0] = (result[q_offset].f[0] * alpha) + (cur_value.f[0] * beta);
                result[q_offset].f[1] = (result[q_offset].f[1] * alpha) + (cur_value.f[1] * beta);
                result[q_offset].f[2] = (result[q_offset].f[2] * alpha) + (cur_value.f[2] * beta);
                result[q_offset].f[3] = (result[q_offset].f[3] * alpha) + (cur_value.f[3] * beta);

                O[out_index].f4 = result[q_offset].f4;
            }
            filter_id++;
        }
    }
}
"""

    update_code = r"""
{
    __shared__ Matrix %(a_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4];
    __shared__ Matrix %(b_name)s_data[TILE_DIM / 4][THREADS_DIM * 4 + 4];

    //TODO: Use square access pattern to image data to increase cache hits
    int output_pixel, filter_id;
    _idiv_magic(blockIdx.x, magic_pq, shift_pq, output_pixels, filter_id, output_pixel);
    filter_id = filter_id * TILE_DIM;
    int load_filter_id = filter_id + threadIdx.y;

    int filter_pixel_id = blockIdx.y * TILE_DIM;

    //TODO: Zig zag along x axis to increase cache hits
    int output_pixel_x, output_pixel_y;
    _idiv_magic(output_pixel, magic_q, shift_q, grid_q, output_pixel_y, output_pixel_x);

    //Compute input image and filter offsets for this pixel
    int c, rs;
    int crs = filter_pixel_id + threadIdx.y;
    _idiv_magic(crs, MAGIC_FILTER_SIZE, SHIFT_FILTER_SIZE, FILTER_SIZE, c, rs);

    int filter_x, filter_y;
    _idiv_magic32(rs, magic_s, shift_s, S, filter_y, filter_x);

    int output_pixel_x_save = output_pixel_x;
    for(; output_pixel_y < P; output_pixel_y += grid_p)
    {
        for(output_pixel_x = output_pixel_x_save; output_pixel_x < Q; output_pixel_x += grid_q)
        {
            int base_x = output_pixel_x * stride_w - padding_w + filter_x;
            int base_y = output_pixel_y * stride_h - padding_h + filter_y;
            int crs_in_bounds = (c < C) && (base_x >= 0) && (base_x < W) &&
                                (base_y >= 0) && (base_y < H);
            int input_pixel = W * base_y + base_x;
            output_pixel = output_pixel_x + (Q * output_pixel_y);

            //Pre-multiply offset to simplify indexing
            input_pixel = (input_pixel * N) >> 2;
            output_pixel = (output_pixel * N) >> 2;

            //Evaluate gemm with outer product dimensions N, K and inner product CRS
            Matrix result[ITEMS_PER_THREAD] = {0};

            //Load image data and transpose into shared mem
            //TODO: pad shared memory to avoid bank conflicts
            Matrix buffer;
            buffer.f4 = crs_in_bounds ?
                        I[(c * input_channel_size) + input_pixel + threadIdx.x].f4 :
                        make_float4(0, 0, 0, 0);
            %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0];
            %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1];
            %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2];
            %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3];

            //Load error data and transpose into shared mem
            buffer.f4 = (load_filter_id < K) ?
                        F[(load_filter_id * output_filter_size) + output_pixel + threadIdx.x].f4 :
                        make_float4(0, 0, 0, 0);
            %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[0];
            %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[1];
            %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[2];
            %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] = buffer.f[3];

            //Iterate over entire minibatch
            for(int n = threadIdx.x + (TILE_DIM >> 2); n < (N >> 2); n += (TILE_DIM >> 2))
            {
                __syncthreads();

                #pragma unroll
                for(int i = 0; i < (TILE_DIM >> 2); i++)
                {
                    Matrix row_image;
                    Matrix row_error;

                    row_image.f4 =
                        %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4;
                    row_error.f4 =
                        %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4;

                    //Accumulate product
                    #pragma unroll
                    for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++)
                    {
                        #pragma unroll
                        for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++)
                        {
                            result[p_offset].f[q_offset] +=
                                (row_image.f[p_offset] * row_error.f[q_offset]);
                        }
                    }
                }

                __syncthreads();

                //Load image data and transpose into shared mem
                buffer.f4 = crs_in_bounds ?
                    I[(c * input_channel_size) + input_pixel + n].f4 :
                    make_float4(0, 0, 0, 0);
                %(a_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[0];
                %(a_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[1];
                %(a_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[2];
                %(a_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[3];

                //Load error data and transpose into shared mem
                buffer.f4 = (load_filter_id < K) ?
                    F[(load_filter_id * output_filter_size) + output_pixel + n].f4 :
                    make_float4(0, 0, 0, 0);
                %(b_name)s_data[threadIdx.x][ 0 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[0];
                %(b_name)s_data[threadIdx.x][ 8 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[1];
                %(b_name)s_data[threadIdx.x][16 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[2];
                %(b_name)s_data[threadIdx.x][24 | threadIdx.y >> 2].f[threadIdx.y & 3] =
                    buffer.f[3];
            }

            __syncthreads();

            //Accumulate product for last iteration
            #pragma unroll
            for(int i = 0; i < (TILE_DIM >> 2); i++)
            {
                Matrix row_image;
                Matrix row_error;

                row_image.f4 = %(a_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.y >> 2].f4;
                row_error.f4 = %(b_name)s_data[i][((threadIdx.y & 3) << 3) | threadIdx.x].f4;

                //Accumulate product
                #pragma unroll
                for(int q_offset = 0; q_offset < ITEMS_PER_THREAD; q_offset++)
                {
                    #pragma unroll
                    for(int p_offset = 0; p_offset < ITEMS_PER_THREAD; p_offset++)
                    {
                        result[p_offset].f[q_offset] +=
                            (row_image.f[p_offset] * row_error.f[q_offset]);
                    }
                }
            }

            //Reduce result between threads in warp
            Matrix outbound;
            int warp_y = threadIdx.y & 3;
            int warp_id = threadIdx.x + (threadIdx.y << 3);
            buffer.f4 = (warp_y == 0) ? result[0].f4 :
                        (warp_y == 1) ? result[1].f4 :
                        (warp_y == 2) ? result[2].f4 :
                        result[3].f4;

            outbound.f4 = (warp_y == 0) ? result[3].f4 :
                          (warp_y == 1) ? result[0].f4 :
                          (warp_y == 2) ? result[1].f4 :
                          result[2].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 8);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 8);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 8);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 8);

            outbound.f4 = (warp_y == 0) ? result[2].f4 :
                          (warp_y == 1) ? result[3].f4 :
                          (warp_y == 2) ? result[0].f4 :
                          result[1].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 16);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 16);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 16);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 16);

            outbound.f4 = (warp_y == 0) ? result[1].f4 :
                          (warp_y == 1) ? result[2].f4 :
                          (warp_y == 2) ? result[3].f4 :
                          result[0].f4;
            buffer.f[0] += __shfl(outbound.f[0], warp_id + 24);
            buffer.f[1] += __shfl(outbound.f[1], warp_id + 24);
            buffer.f[2] += __shfl(outbound.f[2], warp_id + 24);
            buffer.f[3] += __shfl(outbound.f[3], warp_id + 24);

            //Store result
            int idx_filter_id = filter_id + (threadIdx.x << 2);
            if(idx_filter_id < K && crs_in_bounds)
            {
                int out_index = (c * filter_channel_size) + (((rs * K) + (idx_filter_id)) >> 2);

                atomicAdd(&O[out_index].f[0], buffer.f[0]);
                atomicAdd(&O[out_index].f[1], buffer.f[1]);
                atomicAdd(&O[out_index].f[2], buffer.f[2]);
                atomicAdd(&O[out_index].f[3], buffer.f[3]);
            }
        }
    }
}
"""
    if operation == "update":
        code = header_code + update_code
    else:
        code = header_code + code

    magic = _magic64(filter_size)

    code = code % {
        "filter_size": filter_size,
        "magic_filter_size": magic[0],
        "shift_filter_size": magic[1],
        "type": _ew_types[dtype]["type"],
        "lut_code": lut_code,
        "bsum_code": bsum_code,
        "operation": operation,
        "a_name": a_name,
        "b_name": b_name,
        "filter_load_cond": filter_load_cond,
        "check_filter_cond": check_filter_cond
    }

    options = ["--use_fast_math"]
    if debug and operation == "bprop":
        options = options + ["-g", "-G"]
    module = SourceModule(code, options=options)

    kernel = module.get_function("conv_" + operation)
    kernel.prepare("ffPPPPIIIIIIIIIIIIIIIIIIIIIIIIIIII")
    kernel.name = "conv_" + operation
    return kernel
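
The magic_pq / magic_q / magic_s kernel arguments and the _magic64(filter_size) call above rely on the classic multiply-shift trick for fast integer division, matching the 64-bit path in the _idiv_magic device helper. Below is a minimal host-side sketch of that idea; the function name _find_division_magic, the brute-force search, and the nmax bound are illustrative assumptions and not the library's _magic64.

def _find_division_magic(d, nmax=1 << 16):
    # Search for (magic, shift) such that, for every 0 <= n <= nmax,
    #     n // d == (n * magic) >> (32 + shift)
    # which is the arithmetic _idiv_magic performs on the device
    # (64-bit multiply, take the high word, shift right).  nmax is a small
    # illustrative bound; a real implementation derives the pair directly.
    if d == 1:
        return 1, 0  # the kernels special-case magic == 1 as a plain shift
    for shift in range(32):
        magic = -(-(1 << (32 + shift)) // d)  # ceil(2**(32 + shift) / d)
        if all((n * magic) >> (32 + shift) == n // d for n in range(nmax + 1)):
            return magic, shift
    raise ValueError("no magic number found for divisor %d" % d)

# Example: fast division by a filter width S = 5 (hypothetical value).
magic_s, shift_s = _find_division_magic(5)
assert (1234 * magic_s) >> (32 + shift_s) == 1234 // 5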
Exemplo n.º 15
0
def _get_lut_bprop_kernel(dtype, deterministic=False):
    """
    Builds the bprop kernel for lookup table layers based on templated code.
    If the deterministic version is requested, an index buffer must be passed
    as an argument. This index buffer re-orders items in the input tensor
    so that word_ids are sorted. This is required so that each word id's
    weight row is updated by exactly one thread block, which keeps the
    accumulation order deterministic (no atomic additions are needed).

    Arguments:
        dtype (np.dtype): The data type which the kernel will operate on.
        deterministic (boolean): Builds the deterministic kernel when this is
            set to True.
    """
    if not deterministic:
        code = r"""
__global__ void lut_bprop(
    int* inputs, %(type)s* dW, %(type)s* errors, const int nin,
    const int embedding_dim, const int vocab_size, const int pad_idx)
{
    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;

    int word_id = inputs[bid];
    int error_row = bid * embedding_dim;
    int output_row = word_id * embedding_dim;

    if(word_id != pad_idx)
    {
        for(int i = tid; i < embedding_dim; i += blockDim.x)
        {
            atomicAdd(&dW[output_row + i], errors[error_row + i]);
        }
    }
}
"""

        code = code % {
            "type": _ew_types[dtype]["type"]
        }

        module = SourceModule(code, options=["--use_fast_math"])
        kernel = module.get_function("lut_bprop")
        kernel.prepare("PPPIIIi")
    else:
        code = r"""
__global__ void lut_bprop(
    int* inputs, int* index_buffer, %(type)s* dW, %(type)s* errors,
    const int nin, const int embedding_dim, const int vocab_size,
    const int pad_idx)
{
    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;

    int index_position = bid;
    int index = index_buffer[index_position];
    int word_id = inputs[index];

    if((bid == 0 || word_id != inputs[index_buffer[bid - 1]]) && word_id != pad_idx)
    {
        int output_row = word_id * embedding_dim;

        do {
            int error_row = index * embedding_dim;

            for(int i = tid; i < embedding_dim; i += blockDim.x)
            {
                dW[output_row + i] += errors[error_row + i];
            }

            index_position++;
            if(index_position == gridDim.x)
            {
                break;
            }
            index = index_buffer[index_position];
        } while(inputs[index] == word_id);
    }
}
"""

        code = code % {
            "type": _ew_types[dtype]["type"]
        }

        module = SourceModule(code, options=["--use_fast_math"])
        kernel = module.get_function("lut_bprop")
        kernel.prepare("PPPPIIIi")

    kernel.name = "lut_bprop"
    return kernel
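
A hedged host-side sketch (not part of the library) of how the deterministic variant might be driven: the index buffer is simply a stable argsort of the word ids, and the kernel is launched with one block per input position. The wrapper name, block size, and PyCUDA calls used here are illustrative assumptions.

import numpy as np
import pycuda.autoinit          # noqa: F401  (establishes a CUDA context)
import pycuda.gpuarray as gpuarray

def lut_bprop_deterministic(inputs_d, errors_d, dW_d, embedding_dim,
                            vocab_size, pad_idx, kernel):
    nin = int(inputs_d.size)
    # A stable argsort groups equal word ids together, so the block that owns
    # the first occurrence of a word id walks forward over all of its
    # occurrences (this is what the kernel's do/while loop assumes).
    index_buffer = np.argsort(inputs_d.get(), kind="stable").astype(np.int32)
    index_d = gpuarray.to_gpu(index_buffer)
    block = (32, 1, 1)   # threads stride over embedding_dim
    grid = (nin, 1)      # one block per input position; kernel stops at gridDim.x
    kernel.prepared_call(grid, block,
                         inputs_d.gpudata, index_d.gpudata,
                         dW_d.gpudata, errors_d.gpudata,
                         nin, embedding_dim, vocab_size, pad_idx)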
Exemplo n.º 16
0
def _get_nms_kernel():

    code = r"""
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned int) * 8;

__device__ inline float devIoU(float const * const a, float const * const b) {
  float left = max(a[0], b[0]), right = min(a[2], b[2]);
  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
  float interS = width * height;
  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
  return interS / (Sa + Sb - interS);
}

__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
                           const float *dev_boxes, unsigned int *dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  const int row_size =
        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  __shared__ float block_boxes[threadsPerBlock * 5];
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x * 5 + 0] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
    block_boxes[threadIdx.x * 5 + 1] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
    block_boxes[threadIdx.x * 5 + 2] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
    block_boxes[threadIdx.x * 5 + 3] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
    block_boxes[threadIdx.x * 5 + 4] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const float *cur_box = dev_boxes + cur_box_idx * 5;
    int i = 0;
    unsigned int t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;
    }
    for (i = start; i < col_size; i++) {
      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
        t |= 1UL << i;
      }
    }
    const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
    dev_mask[cur_box_idx * col_blocks + col_start] = t;
  }
}

"""

    module = SourceModule(code)
    kernel = module.get_function("nms_kernel")
    sig = "1I 1f 2P"
    kernel.prepare(sig)
    return kernel
Exemplo n.º 17
0
def _get_hist_kernel(dtype_str, nbins, offset):
    """
    Build a kernel to compute a 64-bin histogram.

    Use templating to generate a customized kernel depending on the input data type.

    Memoized to avoid compiling the same kernel twice.
    """
    type_str = _ew_types[dtype_str[1:]]
    from string import Template
    code = Template(_common_fp16_to_fp32 + r"""

#define MAX(a,b) (a > b ? a : b)
#define MIN(a,b) (a < b ? a : b)

__global__ void kernel_histo (
    int* d_hist, const $in_type* a1_in,
    int strides, int size)
{
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    __shared__ int s[$nbins];
    if(tid < $nbins){
        s[tid] = 0;
    }

    if(bid == 0 && tid < $nbins){
        d_hist[tid] = 0;
    }

    // make sure the shared-memory bins are zeroed by the whole block before
    // any thread starts accumulating into them
    __syncthreads();

    for (int i = tid + blockDim.x*bid; i < size; i += strides)
    {
        float a1 = $convert_to_float(__ldg(a1_in + i));

        float absval = fabs(a1);

        float logabs = round(log2f(absval));

        int bin = MIN($nbins-1, MAX(0, logabs-($offset)));

        atomicAdd(&s[bin], 1);

    }

    __syncthreads();

    if(tid < $nbins){
        atomicAdd(&d_hist[tid], s[tid]);
    }
}
""")

    module = SourceModule(code.substitute(in_type=type_str['type'],
                                          convert_to_float=type_str['cvt'],
                                          nbins=nbins,
                                          offset=offset),
                          options=[])
    kernel = module.get_function("kernel_histo")
    kernel.prepare("PPII")
    return kernel
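# A hedged usage sketch for the kernel above (not from the source): the dtype
# string, bin count, offset, and launch shape are assumptions.  `strides` is the
# total thread count, so the loop in the kernel is a grid-stride loop.
import numpy as np
import pycuda.autoinit  # assumes no CUDA context is already managed elsewhere
from pycuda import gpuarray

nbins, offset = 64, -48                       # assumed values
kernel = _get_hist_kernel(np.dtype(np.float16).str, nbins, offset)  # "<f2" -> "f2"

size = 1 << 20
a_dev = gpuarray.to_gpu(np.random.randn(size).astype(np.float16))
d_hist = gpuarray.zeros(nbins, dtype=np.int32)

block = (256, 1, 1)
grid = (128, 1, 1)
strides = block[0] * grid[0]
kernel.prepared_call(grid, block, d_hist.gpudata, a_dev.gpudata, strides, size)
print(d_hist.get())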
Example no. 18
0
def _get_fprop_roipooling(clss):

    code = r"""
#define FLT_MAX 3.402823466E+38F

__global__ void fprop_roipooling(const int nthreads,
    const int num_rois, const int img_count,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width,
    const float* bottom_data, const float* bottom_rois, float* top_data,
    int* argmax_data, const float spatial_scale) {
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; \
        index < (nthreads); index += blockDim.x * gridDim.x){
        // (c, ph, pw, n) is an element in the pooled output
        int n = index % num_rois;
        int pw = (index / num_rois) % pooled_width;
        int ph = (index / num_rois / pooled_width) % pooled_height;
        int c = index / num_rois / pooled_width / pooled_height;

        bottom_rois += n * 5;
        int roi_batch_ind = bottom_rois[0];
        int roi_start_w = round(bottom_rois[1] * spatial_scale);
        int roi_start_h = round(bottom_rois[2] * spatial_scale);
        int roi_end_w = round(bottom_rois[3] * spatial_scale);
        int roi_end_h = round(bottom_rois[4] * spatial_scale);

        // Force malformed ROIs to be 1x1
        int roi_width = max(roi_end_w - roi_start_w + 1, 1);
        int roi_height = max(roi_end_h - roi_start_h + 1, 1);
        float bin_size_h = static_cast<float>(roi_height)
                           / static_cast<float>(pooled_height);
        float bin_size_w = static_cast<float>(roi_width)
                           / static_cast<float>(pooled_width);

        int hstart = static_cast<int>(floor(static_cast<float>(ph)
                                            * bin_size_h));
        int wstart = static_cast<int>(floor(static_cast<float>(pw)
                                            * bin_size_w));
        int hend = static_cast<int>(ceil(static_cast<float>(ph + 1)
                                         * bin_size_h));
        int wend = static_cast<int>(ceil(static_cast<float>(pw + 1)
                                         * bin_size_w));

        // Add roi offsets and clip to input boundaries
        hstart = min(max(hstart + roi_start_h, 0), height);
        hend = min(max(hend + roi_start_h, 0), height);
        wstart = min(max(wstart + roi_start_w, 0), width);
        wend = min(max(wend + roi_start_w, 0), width);
        bool is_empty = (hend <= hstart) || (wend <= wstart);

        // Define an empty pooling region to be zero
        float maxval = is_empty ? 0 : -FLT_MAX;
        // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
        int maxidx = -1;

        bottom_data += c * height * width * img_count;

        for (int h = hstart; h < hend; ++h) {
          for (int w = wstart; w < wend; ++w) {
            int bottom_index = h * width * img_count + w * img_count + roi_batch_ind;
            if (bottom_data[bottom_index] > maxval) {
              maxval = bottom_data[bottom_index];
              maxidx = bottom_index;
            }
          }
        }
        top_data[index] = maxval;
        argmax_data[index] = maxidx;
        // Note: maxidx (from bottom_index) is an offset into this channel's
        // (h, w, img_count) feature map slice, so it ranges over [0, H*W*N)
    }
}

"""

    module = SourceModule(code)
    kernel = module.get_function("fprop_roipooling")
    sig = "8I 4P 1f"
    kernel.prepare(sig)
    return kernel
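# A hedged usage sketch for the kernel above (not from the source): tensor
# layouts (C, H, W, N) for the feature map and (C, PH, PW, R) for the pooled
# output are inferred from the indexing; the clss argument is unused in the
# body shown, so None is passed here.
import numpy as np
import pycuda.autoinit  # assumes no CUDA context is already managed elsewhere
from pycuda import gpuarray

kernel = _get_fprop_roipooling(None)

C, H, W, N = 8, 14, 14, 2                     # channels, height, width, batch
R, PH, PW = 4, 6, 6                           # rois, pooled height, pooled width
spatial_scale = 1.0 / 16

fm = gpuarray.to_gpu(np.random.rand(C, H, W, N).astype(np.float32))
# each ROI row: (batch_index, x1, y1, x2, y2) in original image coordinates
rois = gpuarray.to_gpu(
    np.array([[n % N, 0, 0, 100, 100] for n in range(R)], dtype=np.float32))
top = gpuarray.zeros((C, PH, PW, R), dtype=np.float32)
argmax = gpuarray.zeros((C, PH, PW, R), dtype=np.int32)

nthreads = C * PH * PW * R                    # one thread per pooled output element
block = (256, 1, 1)
grid = ((nthreads + block[0] - 1) // block[0], 1, 1)
kernel.prepared_call(grid, block,
                     nthreads, R, N, C, H, W, PH, PW,
                     fm.gpudata, rois.gpudata, top.gpudata, argmax.gpudata,
                     spatial_scale)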
Example no. 19
0
def _get_compound_kernel(type_args, compute_capability):
    """
    generate compound kernel for the optree from type_args
    """

    # from the stack, rebuild a mutable tree
    tree = _build_tree(type_args)
    # _print_tree(tree)
    # exit()

    # split all reductions and post reduction scalar operations out of the tree
    # sub-trees are converted to stacks and pushed onto stages list
    stages = _split_stages(tree)
    # _print_tree(tree)
    # exit()

    # set the final stage type to type of output (scalar or elementwise)
    last_stage = "red_out" if tree[1] == 1 else "ew_out"
    # convert the remainder of tree to stack
    stages.append((last_stage, _post_order(tree)))

    # for stage, stage_data in enumerate(stages):
    #     print stage_data[0], stage
    #     for s in stage_data[1]: print s
    #     print
    # exit()

    stack = list()
    placeholders = list()
    stage_out_reg = dict()
    arg_dict = dict()
    array_ids = set()
    fp16In = False
    rand_init = False
    rand_func = False
    threads = type_args[-1][3]
    template = _ew_template
    template_vals = {
        "threads": threads,
        "name": _get_kernel_name(),
        "common": list(),
        "inits": list(),
        "finish": list(),
    }

    for stage, stage_data in enumerate(stages):

        stage_type, stage_stack = stage_data
        new_placeholders = list()

        # build out the template as we process stages
        if stage_type == "reduction":

            new_placeholders.append("loads%d" % stage)
            new_placeholders.append("ops%d" % stage)
            new_placeholders.append("shfl_red%d" % stage)
            template += _stage_template["loop"].format(stage)
            if threads > 32:
                new_placeholders.append("var_red%d" % stage)
                new_placeholders.append("share1_red%d" % stage)
                new_placeholders.append("share2_red%d" % stage)
                template += _stage_template["red"].format(stage)
            else:
                template += _stage_template["red32"].format(stage)

        elif stage_type == "scalar":

            new_placeholders.append("ops%d" % stage)
            template += _stage_template["red_ops"].format(stage)

        elif stage_type == "red_out":

            new_placeholders.append("ops%d" % stage)
            template += _stage_template["red_out"].format(stage)

        else:  # ew_out

            new_placeholders.append("loads%d" % stage)
            new_placeholders.append("ops%d" % stage)
            template += _stage_template["loop"].format(stage)

        for key in new_placeholders:
            template_vals[key] = []
        placeholders.extend(new_placeholders)

        for arg_i, arg in enumerate(stage_stack):

            arg_type, arg_id = arg[0:2]

            # Array operands
            if arg_type is ng.GPUTensor:

                dtype, take_axis = arg[2:4]

                is_out_tensor = True if stage == len(
                    stages) - 1 and arg_i == 0 else False

                # first arg is output array, don't put on stack
                if is_out_tensor:
                    out_dtype = dtype
                    out_take = take_axis
                else:
                    stack.append("a%d" % arg_id)

                # 0: arg_id, 1: stage, 2: type, 3: cvt
                ew_dtype = _ew_types[dtype]
                fmt = (arg_id, stage, ew_dtype["type"], ew_dtype["cvt"])

                # First time we see a tensor initialize everything
                if arg_id not in array_ids:

                    array_ids.add(arg_id)
                    array_ids.add((arg_id, stage))

                    sig = "Pii"
                    if take_axis > 0:
                        sig += "P"

                    # output tensor
                    if is_out_tensor:
                        ew_out = _ew_strings["out%d" % take_axis]
                        arguments = ew_out["arguments"].format(*fmt)
                        template_vals["inits"].append(
                            ew_out["inits"].format(*fmt))
                    # input tensors
                    else:
                        ew_in = _ew_strings["in%d" % take_axis]
                        loads = "loads%d" % stage
                        arguments = ew_in["arguments"].format(*fmt)
                        template_vals["inits"].append(
                            ew_in["inits"].format(*fmt))
                        template_vals[loads].append(
                            ew_in["loads"].format(*fmt))

                    if dtype == 'f2' and not fp16In:
                        template_vals["common"].append(_common_fp16_to_fp32)
                        fp16In = True

                    arg_dict[arg] = (sig, arguments)

                # Subsequent times we see a tensor just initialize inits and
                # loads
                elif (arg_id, stage) not in array_ids:
                    array_ids.add((arg_id, stage))
                    ew_in = _ew_strings["in%d" % take_axis]
                    loads = "loads%d" % stage
                    template_vals["inits"].append(ew_in["inits"].format(*fmt))
                    template_vals[loads].append(ew_in["loads"].format(*fmt))

            # Constant operands
            elif arg_type is float:

                stack.append("c%d" % arg_id)
                if arg not in arg_dict:
                    arg_dict[arg] = (
                        "f", _ew_strings["const"]["arguments"].format(arg_id))

            # Operations (arg_type = op_name)
            else:

                if arg_type == "assign":

                    ops = "ops%d" % stage

                    # loop end condition for last stage
                    sig = "i"
                    arguments = ["const int n%d" % stage]

                    # rounding mode
                    if arg[2]:
                        mode = "random"
                        sig += "i"
                        arguments.append("const int mantissa_bits")
                        if not rand_init:
                            rand_init = _init_rand(template_vals)
                        template_vals["inits"].append(_init_rand_round_func)
                    else:
                        mode = "nearest"

                    arg_dict[arg] = (sig, ", ".join(arguments))

                    out_val = stack.pop()
                    # if the last stack value came from an argmax/min just do
                    # implicit type conversion
                    if out_val[0] == "i" and out_dtype[0] in "iu":
                        ew_round = None
                    else:
                        ew_round = _ew_strings["round"][
                            mode].get(out_dtype, None)
                        ew_common = _common_round[mode].get(out_dtype, None)
                        if ew_common:
                            template_vals["common"].append(ew_common)

                    if ew_round:
                        round_val = "r%d" % arg_id
                        template_vals[ops].append(
                            ew_round.format(round_val, out_val))
                    else:
                        round_val = out_val

                    template_vals[ops].append(
                        _ew_strings["out%d" % out_take]["output"].format(round_val))

                elif arg in stage_out_reg:

                    stack.append(stage_out_reg[arg])

                elif arg_type in _float_ops:

                    if len(template_vals["name"]) < 16:
                        template_vals["name"].append(arg_type)

                    ops = "ops%d" % stage

                    (num_ops, op_code) = _float_ops[arg_type]

                    if arg_type == "rand":
                        if not rand_init:
                            rand_init = _init_rand(template_vals)
                        if not rand_func:
                            template_vals["common"].append(_common_frand)
                            rand_func = True

                    op_list = ["r%d" % arg_id]

                    # build the operands from the stack
                    for i in range(num_ops):
                        op_list.append(stack.pop())

                    if arg_type == "onehot":

                        hot_axis = arg[2]
                        test_val = "i" if hot_axis else "bid"

                        ew_in = _ew_strings[arg_type + native_str(hot_axis)]
                        loads = "loads%d" % stage
                        template_vals["inits"].append(
                            ew_in["inits"].format(arg_id))
                        template_vals[loads].append(
                            ew_in["loads"].format(arg_id))
                        op_list.append("onehot%d" % arg_id)
                        op_list.append(test_val)

                        arg_dict[arg] = (
                            "P", ew_in["arguments"].format(arg_id))

                    template_vals[ops].append(op_code.format(*op_list))

                    # if this is the last op on the current stage's stack, store its
                    # output register in the stage output dict
                    if arg_i == len(stage_stack) - 1:
                        stage_out_reg[arg] = op_list[0]
                    # otherwise push the reg onto the stack as normal
                    else:
                        stack.append(op_list[0])

                elif arg_type in _reduction_ops:

                    if len(template_vals["name"]) < 16:
                        template_vals["name"].append(arg_type)

                    # loop end condition for current stage
                    # add regardless of duplicate reduction stage
                    arg_dict[arg] = ("i", "const int n%d" % stage)

                    # avoid float conversion for argmax/min
                    reg = "i" if "arg" == arg_type[0:3] else "r"

                    ops = "ops%d" % stage
                    shfl_red = "shfl_red%d" % stage
                    red_arg = "%s%d" % (reg, arg_id)
                    red_strings = _reduction_ops[arg_type]
                    stack_arg = stack.pop()

                    template_vals["inits"].append(
                        red_strings["inits"].format(red_arg))
                    template_vals[ops].append(
                        red_strings["ops"].format(red_arg, stack_arg))
                    template_vals[shfl_red].append(
                        red_strings["shfl_red"].format(red_arg))
                    if threads > 32:
                        var_red = "var_red%d" % stage
                        shr1_red = "share1_red%d" % stage
                        shr2_red = "share2_red%d" % stage
                        template_vals[var_red].append(red_arg)
                        template_vals[shr1_red].append(
                            red_strings["share1_red"].format(red_arg))
                        template_vals[shr2_red].append(
                            red_strings["share2_red"].format(red_arg))

                    # reduction ops are always the last on the stack
                    # just store the register state in the stage output dict
                    stage_out_reg[arg] = red_arg

                else:
                    raise ValueError("Bad op type.")

    if (compute_capability[0] == 3 and compute_capability[1] < 5) or compute_capability[0] < 3:
        template_vals["common"].append(_common_kepler)

    template += _fin_template

    # since we reordered the operations we need to generate the argument list
    # in the original order (the illustrative note after this function describes
    # how the resulting sig string is consumed)
    sig = "P"
    arguments = list()
    unused = 1
    for arg in type_args:
        params = arg_dict.get(arg, False)
        if params:
            sig += params[0]
            arguments.append(params[1])
            del arg_dict[arg]
        # fill in the loop counter for the duplicate reductions that were
        # removed
        elif arg[0] in _reduction_ops:
            sig += "i"
            arguments.append("const int unused%d" % unused)
            unused += 1

    # convert lists to strings
    template_vals["name"] = "_".join(template_vals["name"])
    template_vals["common"] = "\n".join(template_vals["common"])
    template_vals["arguments"] = ",\n    ".join(arguments)
    template_vals["inits"] = "\n    ".join(template_vals["inits"])
    template_vals["finish"] = "\n".join(template_vals["finish"])

    # add the dynamic placeholders: loads#, ops#, reduction#
    for key in placeholders:
        template_vals[key] = "\n        ".join(template_vals[key])

    # populate the template
    code = template % template_vals

    # debugging:
    # print "Compiling %s" % template_vals["name"]
    # f = open("kernel.cu", "w")
    # f = open("%s.cu" % template_vals["name"], "w")
    # print >>f, code
    # f.close()

    # ,"-G" , keep=False
    # module = SourceModule(code, options=["--use_fast_math"])
    module = SourceModule(code, options=[])
    kernel = module.get_function(template_vals["name"])
    kernel.name = template_vals["name"]
    kernel.prepare(sig)

    return kernel
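# Illustrative note on the code above (an aside, not from the source): the
# characters accumulated in `sig` are Python struct format codes as expected by
# pycuda's Function.prepare() -- "P" for a device pointer, "I"/"i" for
# unsigned/signed ints, "f" for a float -- and prepared_call() later supplies
# the launch arguments positionally, in the same original-order list that is
# rebuilt from type_args above.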
Example no. 20
0
def _get_lut_bprop_kernel(dtype, deterministic=False):
    """
    Builds the bprop kernel for lookup table layers based on templated code.
    If the deterministic version is requested, an index buffer must be passed
    as an argument. This index buffer re-orders items in the input tensor
    so that word_ids are sorted. This is required since we need to be sure that
    each thread only updates weights for one word id.

    Arguments:
        dtype (np.dtype): The data which the kernel will operate on.
        deterministic (boolean): Builds the deterministic kernel when this is
            set to True.
    """
    if not deterministic:
        code = r"""
__global__ void lut_bprop(
    int* inputs, %(type)s* dW, %(type)s* errors, const int nin,
    const int embedding_dim, const int vocab_size, const int pad_idx)
{
    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;

    int word_id = inputs[bid];
    int error_row = bid * embedding_dim;
    int output_row = word_id * embedding_dim;

    if(word_id != pad_idx)
    {
        for(int i = tid; i < embedding_dim; i += blockDim.x)
        {
            atomicAdd(&dW[output_row + i], errors[error_row + i]);
        }
    }
}
"""

        code = code % {"type": _ew_types[dtype]["type"]}

        module = SourceModule(code, options=["--use_fast_math"])
        kernel = module.get_function("lut_bprop")
        kernel.prepare("PPPIIIi")
    else:
        code = r"""
__global__ void lut_bprop(
    int* inputs, int* index_buffer, %(type)s* dW, %(type)s* errors,
    const int nin, const int embedding_dim, const int vocab_size,
    const int pad_idx)
{
    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;

    int index_position = bid;
    int index = index_buffer[index_position];
    int word_id = inputs[index];

    if((bid == 0 || word_id != inputs[index_buffer[bid - 1]]) && word_id != pad_idx)
    {
        int output_row = word_id * embedding_dim;

        do {
            int error_row = index * embedding_dim;

            for(int i = tid; i < embedding_dim; i += blockDim.x)
            {
                dW[output_row + i] += errors[error_row + i];
            }

            index_position++;
            if(index_position == gridDim.x)
            {
                break;
            }
            index = index_buffer[index_position];
        } while(inputs[index] == word_id);
    }
}
"""

        code = code % {"type": _ew_types[dtype]["type"]}

        module = SourceModule(code, options=["--use_fast_math"])
        kernel = module.get_function("lut_bprop")
        kernel.prepare("PPPPIIIi")

    kernel.name = "lut_bprop"
    return kernel
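# A hedged usage sketch for the deterministic path above (not from the source):
# the dtype key, shapes, and launch geometry are assumptions.  The index buffer
# sorts input positions by word id so that one block owns each run of equal
# word ids.
import numpy as np
import pycuda.autoinit  # assumes no CUDA context is already managed elsewhere
from pycuda import gpuarray

vocab_size, embedding_dim, nin, pad_idx = 1000, 64, 128, -1
kernel = _get_lut_bprop_kernel("f4", deterministic=True)  # assumes "f4" is an _ew_types key

inputs_h = np.random.randint(0, vocab_size, size=nin).astype(np.int32)
index_buffer_h = np.argsort(inputs_h, kind="stable").astype(np.int32)

inputs = gpuarray.to_gpu(inputs_h)
index_buffer = gpuarray.to_gpu(index_buffer_h)
errors = gpuarray.to_gpu(np.random.randn(nin, embedding_dim).astype(np.float32))
dW = gpuarray.zeros((vocab_size, embedding_dim), dtype=np.float32)

# one block per input position; only blocks that start a run of a word id do work
kernel.prepared_call((nin, 1, 1), (32, 1, 1),
                     inputs.gpudata, index_buffer.gpudata, dW.gpudata, errors.gpudata,
                     nin, embedding_dim, vocab_size, pad_idx)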
Example no. 21
0
def _get_bprop_roipooling(clss):

    code = r"""
__global__ void bprop_roipooling(const int nthreads,
    const int num_rois, const int img_count,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width,
    const float* top_diff, const float* bottom_rois, float* bottom_diff,
    const int* argmax_data, const float spatial_scale) {
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; \
        index < (nthreads); index += blockDim.x * gridDim.x){
        // (c, h, w, n) coords in bottom data on feature map
        int n = index % img_count;
        int w = (index / img_count) % width;
        int h = (index / img_count / width) % height;
        int c = index / img_count / width / height;

        float gradient = 0;
        // Accumulate gradient over all ROIs that pooled this element
        for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
          const float* offset_bottom_rois = bottom_rois + roi_n * 5;
          int roi_batch_ind = offset_bottom_rois[0];
          // Skip if ROI's batch index doesn't match n
          if (n != roi_batch_ind) {
            continue;
          }

          int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
          int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
          int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
          int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);

          // Skip if ROI doesn't include (h, w)
          const bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
                               h >= roi_start_h && h <= roi_end_h);
          if (!in_roi) {
            continue;
          }

          int offset = c * pooled_height * pooled_width * num_rois;
          const float* offset_top_diff = top_diff + offset;
          const int* offset_argmax_data = argmax_data + offset;

          // Compute feasible set of pooled units that could have pooled
          // this bottom unit

          // Force malformed ROIs to be 1x1
          int roi_width = max(roi_end_w - roi_start_w + 1, 1);
          int roi_height = max(roi_end_h - roi_start_h + 1, 1);

          float bin_size_h = static_cast<float>(roi_height)
                             / static_cast<float>(pooled_height);
          float bin_size_w = static_cast<float>(roi_width)
                             / static_cast<float>(pooled_width);

          int phstart = floor(static_cast<float>(h - roi_start_h) / bin_size_h);
          int phend = ceil(static_cast<float>(h - roi_start_h + 1) / bin_size_h);
          int pwstart = floor(static_cast<float>(w - roi_start_w) / bin_size_w);
          int pwend = ceil(static_cast<float>(w - roi_start_w + 1) / bin_size_w);

          phstart = min(max(phstart, 0), pooled_height);
          phend = min(max(phend, 0), pooled_height);
          pwstart = min(max(pwstart, 0), pooled_width);
          pwend = min(max(pwend, 0), pooled_width);

          for (int ph = phstart; ph < phend; ++ph) {
            for (int pw = pwstart; pw < pwend; ++pw) {
              int top_index = ph * pooled_width * num_rois + pw * num_rois + roi_n;
              int bottom_index = h * width * img_count + w * img_count + roi_batch_ind;
              if (offset_argmax_data[top_index] == bottom_index) {
                gradient += offset_top_diff[top_index];
              }
            }
          }
        }
        bottom_diff[index] = gradient;
    }
}

"""

    module = SourceModule(code)
    kernel = module.get_function("bprop_roipooling")
    sig = "8I 4P 1f"
    kernel.prepare(sig)
    return kernel
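# A hedged usage sketch for the kernel above (not from the source): layouts
# mirror the fprop example, with one thread per bottom (gradient) element and
# the argmax buffer normally produced by the fprop pass.
import numpy as np
import pycuda.autoinit  # assumes no CUDA context is already managed elsewhere
from pycuda import gpuarray

bprop = _get_bprop_roipooling(None)           # clss is unused in the body shown

C, H, W, N = 8, 14, 14, 2
R, PH, PW = 4, 6, 6
spatial_scale = 1.0 / 16

rois = gpuarray.to_gpu(
    np.array([[n % N, 0, 0, 100, 100] for n in range(R)], dtype=np.float32))
top_diff = gpuarray.to_gpu(np.ones((C, PH, PW, R), dtype=np.float32))
argmax = gpuarray.zeros((C, PH, PW, R), dtype=np.int32)   # normally from fprop
bottom_diff = gpuarray.zeros((C, H, W, N), dtype=np.float32)

nthreads = C * H * W * N                      # one thread per bottom element
block = (256, 1, 1)
grid = ((nthreads + block[0] - 1) // block[0], 1, 1)
bprop.prepared_call(grid, block,
                    nthreads, R, N, C, H, W, PH, PW,
                    top_diff.gpudata, rois.gpudata, bottom_diff.gpudata, argmax.gpudata,
                    spatial_scale)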