예제 #1
0
def _get_bn_fprop_kernel(dtype, threads, compute_capability):
    """Generate, compile and prepare the ``batchnorm_fprop`` CUDA kernel.

    The CUDA source is filled in from a template and compiled with
    ``SourceModule`` using ``--use_fast_math``.  In the generated kernel each
    block ``bid`` works on the ``N`` contiguous elements starting at
    ``bid * N``: it accumulates the squared deviation from the mean
    (``xsum_in[bid] / N``), reduces it across the block, stores the result in
    ``xvar_out`` and blends it (and the mean) into ``gvar_out`` /
    ``gmean_out`` with weight ``rho``.  It then writes
    ``gamma * xhat + beta`` to ``y_out``, optionally clamped at zero
    (``relu``) and optionally added to ``accumbeta`` times the previous
    output.  When the kernel's ``binary`` flag is set the products are
    computed through the CUDA ``shift_element`` helper instead of plain
    multiplication.

    Args:
        dtype: Type key used to look up the element type / conversion in
            ``_ew_types`` and the rounding snippets in ``_ew_strings`` and
            ``_common_round``.  ``"f2"`` additionally pulls in
            ``_common_fp16_to_fp32``.
        threads: Value substituted for the ``THREADS`` macro.  Values above
            32 select the shared-memory reduction, otherwise only warp
            shuffles are used.
        compute_capability: Indexable ``(major, minor)`` pair; anything
            below 3.5 appends ``_common_kepler`` to the kernel preamble.

    Returns:
        The kernel function, prepared with the argument format
        ``"PPPPPPPPPPfffIII"`` and with ``name`` set to
        ``"batchnorm_fprop"``.
    """
    # Only blocks wider than 32 threads declare and use the shared buffer.
    # NOTE(review): that path reads sPartials[tid + 32] and halves THREADS,
    # so it presumably expects a power of two >= 64 -- confirm with callers.
    needs_shared = threads > 32
    shared_decl = "__shared__ float sPartials[THREADS];" if needs_shared else ""
    if needs_shared:
        reduction_src = r"""
    sPartials[tid] = xvar;
    __syncthreads();

    #pragma unroll
    for (int a = THREADS >> 1; a > 32; a >>= 1)
    {
        if ( tid < a )
            sPartials[tid] += sPartials[tid + a];
        __syncthreads();
    }
    if ( tid < 32 )
    {
        xvar = sPartials[tid] + sPartials[tid + 32];
        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
            xvar += __shfl_xor(xvar, i);

        sPartials[tid] = xvar * rcpN;
    }
    __syncthreads();
    xvar = sPartials[0];
"""
    else:
        reduction_src = r"""
    #pragma unroll
    for (int i = 16; i > 0; i >>= 1)
        xvar += __shfl_xor(xvar, i);
    xvar *= rcpN;
"""

    template = r"""
#define THREADS %(threads)s

%(common)s
%(binary)s

__global__ void batchnorm_fprop (
    %(type)s* y_out, float* xvar_out, float* gmean_out, float* gvar_out,
    const %(type)s* x_in, const float* xsum_in, const float* gmean_in,
    const float* gvar_in, const float* gamma_in, const float* beta_in,
    const float eps, const float rho, const float accumbeta, const int N,
    const int relu, bool binary)
{
    %(share)s

    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;
    int offset = bid * N;

    const %(type)s* x_in0 = x_in + offset + tid;

    const float rcpN = 1.0f/(float)N;

    float xmean = __ldg(xsum_in + bid) * rcpN;

    float xvar = 0.0f;
    for (int i = tid; i < N; i += THREADS)
    {
        float x = %(cvt)s(__ldg(x_in0));
        x_in0 += THREADS;

        x -= xmean;
        if (binary) {
            xvar += shift_element(x, x, true);
        } else {
            xvar += x * x;
        }
    }
    %(red)s

    float gamma = __ldg(gamma_in + bid);
    float beta  = __ldg(beta_in  + bid);

    if ( tid == 0 )
    {
        float gmean = __ldg(gmean_in + bid);
        float gvar  = __ldg(gvar_in  + bid);

        *(xvar_out  + bid) = xvar;
        *(gmean_out + bid) = gmean * rho + (1.0f - rho) * xmean;
        *(gvar_out  + bid) = gvar  * rho + (1.0f - rho) * xvar;
    }

    float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps);

    int start = N - (THREADS*4 - tid);
    offset += start;
    x_in   += offset;
    y_out  += offset;

    for (int i = start; i >= -THREADS*3; i -= THREADS*4)
    {
        float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in + THREADS*0)) : 0.0f;
        float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in + THREADS*1)) : 0.0f;
        float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in + THREADS*2)) : 0.0f;
        float x3 =                   %(cvt)s(__ldg(x_in + THREADS*3));

        x_in -= THREADS*4;

        float xhat0 = 0.0f;
        float xhat1 = 0.0f;
        float xhat2 = 0.0f;
        float xhat3 = 0.0f;

        float y0 = 0.0f;
        float y1 = 0.0f;
        float y2 = 0.0f;
        float y3 = 0.0f;
        if (binary) {
            xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true);
            xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true);
            xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true);
            xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true);

            y0 = shift_element(xhat0, gamma, true) + beta;
            y1 = shift_element(xhat1, gamma, true) + beta;
            y2 = shift_element(xhat2, gamma, true) + beta;
            y3 = shift_element(xhat3, gamma, true) + beta;
        } else {
            xhat0 = (x0 - xmean) * xvar_rcp_sqrt;
            xhat1 = (x1 - xmean) * xvar_rcp_sqrt;
            xhat2 = (x2 - xmean) * xvar_rcp_sqrt;
            xhat3 = (x3 - xmean) * xvar_rcp_sqrt;

            y0 = xhat0 * gamma + beta;
            y1 = xhat1 * gamma + beta;
            y2 = xhat2 * gamma + beta;
            y3 = xhat3 * gamma + beta;
        }

        if (relu)
        {
            y0 = fmaxf(y0, 0.0f);
            y1 = fmaxf(y1, 0.0f);
            y2 = fmaxf(y2, 0.0f);
            y3 = fmaxf(y3, 0.0f);
        }

        %(y0_out)s
        %(y1_out)s
        %(y2_out)s
        %(y3_out)s
        if (accumbeta == 0.0)
        {
            if (i >= -THREADS*0) *(y_out + THREADS*0) = y0_val;
            if (i >= -THREADS*1) *(y_out + THREADS*1) = y1_val;
            if (i >= -THREADS*2) *(y_out + THREADS*2) = y2_val;
                                 *(y_out + THREADS*3) = y3_val;
        }
        else
        {
            if (i >= -THREADS*0) *(y_out + THREADS*0) = y_out[THREADS*0] * accumbeta + y0_val;
            if (i >= -THREADS*1) *(y_out + THREADS*1) = y_out[THREADS*1] * accumbeta + y1_val;
            if (i >= -THREADS*2) *(y_out + THREADS*2) = y_out[THREADS*2] * accumbeta + y2_val;
                                 *(y_out + THREADS*3) = y_out[THREADS*3] * accumbeta + y3_val;
        }
        y_out -= THREADS*4;
    }
}
"""
    # Statement template that converts a float result into the output type;
    # types without a dedicated rounding snippet just copy the float.
    store_fmt = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};")

    # Helper source emitted ahead of the kernel definition.
    preamble = _common_round["nearest"].get(dtype, "")
    if dtype == "f2":
        preamble = preamble + _common_fp16_to_fp32

    cc_major = compute_capability[0]
    if cc_major < 3 or (cc_major == 3 and compute_capability[1] < 5):
        preamble = preamble + _common_kepler

    fields = {
        "common": preamble,
        "binary": shift_element(),
        "share": shared_decl,
        "red": reduction_src,
        "threads": threads,
    }
    type_info = _ew_types[dtype]
    fields["type"] = type_info["type"]
    fields["cvt"] = type_info["cvt"]
    # One conversion statement per unrolled output value (y0..y3).
    for idx in range(4):
        fields["y%d_out" % idx] = store_fmt.format("y%d_val" % idx, "y%d" % idx)

    module = SourceModule(template % fields, options=["--use_fast_math"])
    kernel = module.get_function("batchnorm_fprop")
    # Ten pointers, three floats, then N / relu / binary.
    # NOTE(review): `binary` is declared `bool` in the kernel but packed as
    # "I" here -- confirm this matches the intended argument layout.
    kernel.prepare("PPPPPPPPPPfffIII")
    kernel.name = "batchnorm_fprop"
    return kernel
예제 #2
0
def _get_bn_fprop_kernel(dtype, threads, compute_capability):
    """Generate, compile and prepare the ``batchnorm_fprop`` CUDA kernel.

    The CUDA source is filled in from a template and compiled with
    ``SourceModule`` using ``--use_fast_math``.  In the generated kernel each
    block ``bid`` works on the ``N`` contiguous elements starting at
    ``bid * N``: it accumulates the squared deviation from the mean
    (``xsum_in[bid] / N``), reduces it across the block, stores the result in
    ``xvar_out`` and blends it (and the mean) into ``gvar_out`` /
    ``gmean_out`` with weight ``rho``.  It then writes
    ``gamma * xhat + beta`` to ``y_out``, optionally clamped at zero
    (``relu``) and optionally added to ``accumbeta`` times the previous
    output.  When the kernel's ``binary`` flag is set the products are
    computed through the CUDA ``shift_element`` helper instead of plain
    multiplication.

    Args:
        dtype: Type key used to look up the element type / conversion in
            ``_ew_types`` and the rounding snippets in ``_ew_strings`` and
            ``_common_round``.  ``"f2"`` additionally pulls in
            ``_common_fp16_to_fp32``.
        threads: Value substituted for the ``THREADS`` macro.  Values above
            32 select the shared-memory reduction, otherwise only warp
            shuffles are used.
        compute_capability: Indexable ``(major, minor)`` pair; anything
            below 3.5 appends ``_common_kepler`` to the kernel preamble.

    Returns:
        The kernel function, prepared with the argument format
        ``"PPPPPPPPPPfffIII"`` and with ``name`` set to
        ``"batchnorm_fprop"``.
    """
    # Only blocks wider than 32 threads declare and use the shared buffer.
    # NOTE(review): that path reads sPartials[tid + 32] and halves THREADS,
    # so it presumably expects a power of two >= 64 -- confirm with callers.
    needs_shared = threads > 32
    shared_decl = "__shared__ float sPartials[THREADS];" if needs_shared else ""
    if needs_shared:
        reduction_src = r"""
    sPartials[tid] = xvar;
    __syncthreads();

    #pragma unroll
    for (int a = THREADS >> 1; a > 32; a >>= 1)
    {
        if ( tid < a )
            sPartials[tid] += sPartials[tid + a];
        __syncthreads();
    }
    if ( tid < 32 )
    {
        xvar = sPartials[tid] + sPartials[tid + 32];
        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
            xvar += __shfl_xor(xvar, i);

        sPartials[tid] = xvar * rcpN;
    }
    __syncthreads();
    xvar = sPartials[0];
"""
    else:
        reduction_src = r"""
    #pragma unroll
    for (int i = 16; i > 0; i >>= 1)
        xvar += __shfl_xor(xvar, i);
    xvar *= rcpN;
"""

    template = r"""
#define THREADS %(threads)s

%(common)s
%(binary)s

__global__ void batchnorm_fprop (
    %(type)s* y_out, float* xvar_out, float* gmean_out, float* gvar_out,
    const %(type)s* x_in, const float* xsum_in, const float* gmean_in,
    const float* gvar_in, const float* gamma_in, const float* beta_in,
    const float eps, const float rho, const float accumbeta, const int N,
    const int relu, bool binary)
{
    %(share)s

    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;
    int offset = bid * N;

    const %(type)s* x_in0 = x_in + offset + tid;

    const float rcpN = 1.0f/(float)N;

    float xmean = __ldg(xsum_in + bid) * rcpN;

    float xvar = 0.0f;
    for (int i = tid; i < N; i += THREADS)
    {
        float x = %(cvt)s(__ldg(x_in0));
        x_in0 += THREADS;

        x -= xmean;
        if (binary) {
            xvar += shift_element(x, x, true);
        } else {
            xvar += x * x;
        }
    }
    %(red)s

    float gamma = __ldg(gamma_in + bid);
    float beta  = __ldg(beta_in  + bid);

    if ( tid == 0 )
    {
        float gmean = __ldg(gmean_in + bid);
        float gvar  = __ldg(gvar_in  + bid);

        *(xvar_out  + bid) = xvar;
        *(gmean_out + bid) = gmean * rho + (1.0f - rho) * xmean;
        *(gvar_out  + bid) = gvar  * rho + (1.0f - rho) * xvar;
    }

    float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps);

    int start = N - (THREADS*4 - tid);
    offset += start;
    x_in   += offset;
    y_out  += offset;

    for (int i = start; i >= -THREADS*3; i -= THREADS*4)
    {
        float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in + THREADS*0)) : 0.0f;
        float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in + THREADS*1)) : 0.0f;
        float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in + THREADS*2)) : 0.0f;
        float x3 =                   %(cvt)s(__ldg(x_in + THREADS*3));

        x_in -= THREADS*4;

        float xhat0 = 0.0f;
        float xhat1 = 0.0f;
        float xhat2 = 0.0f;
        float xhat3 = 0.0f;

        float y0 = 0.0f;
        float y1 = 0.0f;
        float y2 = 0.0f;
        float y3 = 0.0f;
        if (binary) {
            xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true);
            xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true);
            xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true);
            xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true);

            y0 = shift_element(xhat0, gamma, true) + beta;
            y1 = shift_element(xhat1, gamma, true) + beta;
            y2 = shift_element(xhat2, gamma, true) + beta;
            y3 = shift_element(xhat3, gamma, true) + beta;
        } else {
            xhat0 = (x0 - xmean) * xvar_rcp_sqrt;
            xhat1 = (x1 - xmean) * xvar_rcp_sqrt;
            xhat2 = (x2 - xmean) * xvar_rcp_sqrt;
            xhat3 = (x3 - xmean) * xvar_rcp_sqrt;

            y0 = xhat0 * gamma + beta;
            y1 = xhat1 * gamma + beta;
            y2 = xhat2 * gamma + beta;
            y3 = xhat3 * gamma + beta;
        }

        if (relu)
        {
            y0 = fmaxf(y0, 0.0f);
            y1 = fmaxf(y1, 0.0f);
            y2 = fmaxf(y2, 0.0f);
            y3 = fmaxf(y3, 0.0f);
        }

        %(y0_out)s
        %(y1_out)s
        %(y2_out)s
        %(y3_out)s
        if (accumbeta == 0.0)
        {
            if (i >= -THREADS*0) *(y_out + THREADS*0) = y0_val;
            if (i >= -THREADS*1) *(y_out + THREADS*1) = y1_val;
            if (i >= -THREADS*2) *(y_out + THREADS*2) = y2_val;
                                 *(y_out + THREADS*3) = y3_val;
        }
        else
        {
            if (i >= -THREADS*0) *(y_out + THREADS*0) = y_out[THREADS*0] * accumbeta + y0_val;
            if (i >= -THREADS*1) *(y_out + THREADS*1) = y_out[THREADS*1] * accumbeta + y1_val;
            if (i >= -THREADS*2) *(y_out + THREADS*2) = y_out[THREADS*2] * accumbeta + y2_val;
                                 *(y_out + THREADS*3) = y_out[THREADS*3] * accumbeta + y3_val;
        }
        y_out -= THREADS*4;
    }
}
"""
    # Statement template that converts a float result into the output type;
    # types without a dedicated rounding snippet just copy the float.
    store_fmt = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};")

    # Helper source emitted ahead of the kernel definition.
    preamble = _common_round["nearest"].get(dtype, "")
    if dtype == "f2":
        preamble = preamble + _common_fp16_to_fp32

    cc_major = compute_capability[0]
    if cc_major < 3 or (cc_major == 3 and compute_capability[1] < 5):
        preamble = preamble + _common_kepler

    fields = {
        "common": preamble,
        "binary": shift_element(),
        "share": shared_decl,
        "red": reduction_src,
        "threads": threads,
    }
    type_info = _ew_types[dtype]
    fields["type"] = type_info["type"]
    fields["cvt"] = type_info["cvt"]
    # One conversion statement per unrolled output value (y0..y3).
    for idx in range(4):
        fields["y%d_out" % idx] = store_fmt.format("y%d_val" % idx, "y%d" % idx)

    module = SourceModule(template % fields, options=["--use_fast_math"])
    kernel = module.get_function("batchnorm_fprop")
    # Ten pointers, three floats, then N / relu / binary.
    # NOTE(review): `binary` is declared `bool` in the kernel but packed as
    # "I" here -- confirm this matches the intended argument layout.
    kernel.prepare("PPPPPPPPPPfffIII")
    kernel.name = "batchnorm_fprop"
    return kernel
예제 #3
0
def _get_bn_bprop_kernel(dtype, threads, compute_capability):
    """Generate, compile and prepare the ``batchnorm_bprop`` CUDA kernel.

    The CUDA source is filled in from a template and compiled with
    ``SourceModule`` using ``--use_fast_math``.  In the generated kernel each
    block ``bid`` works on the ``N`` contiguous elements starting at
    ``bid * N``: a first pass accumulates ``grad_gamma`` (sum of
    ``xhat * delta``) and ``grad_beta`` (sum of ``delta``), reduces both
    across the block and stores them in ``grad_gamma_out`` /
    ``grad_beta_out``.  A second pass writes
    ``gamma * (delta - (xhat * grad_gamma + grad_beta) / N) / sqrt(xvar + eps)``
    to ``delta_out``.  When the kernel's ``binary`` flag is set the products
    are computed through the CUDA ``shift_element`` helper instead of plain
    multiplication.

    Args:
        dtype: Type key used to look up the element type / conversion in
            ``_ew_types`` and the rounding snippets in ``_ew_strings`` and
            ``_common_round``.  ``"f2"`` additionally pulls in
            ``_common_fp16_to_fp32``.
        threads: Value substituted for the ``THREADS`` macro.  Values above
            32 select the shared-memory reduction, otherwise only warp
            shuffles are used.
        compute_capability: Indexable ``(major, minor)`` pair; anything
            below 3.5 appends ``_common_kepler`` to the kernel preamble.

    Returns:
        The kernel function, prepared with the argument format
        ``"PPPPPPPPfII"`` and with ``name`` set to ``"batchnorm_bprop"``.
    """
    # Only blocks wider than 32 threads declare and use the shared buffer,
    # which holds the two partial sums side by side (hence THREADS * 2).
    # NOTE(review): that path reads sPartials[tid + 32] and halves THREADS,
    # so it presumably expects a power of two >= 64 -- confirm with callers.
    needs_shared = threads > 32
    shared_decl = "__shared__ float sPartials[THREADS * 2];" if needs_shared else ""
    if needs_shared:
        reduction_src = r"""
    sPartials[tid + THREADS*0] = grad_gamma;
    sPartials[tid + THREADS*1] = grad_beta;
    __syncthreads();

    #pragma unroll
    for (int a = THREADS >> 1; a > 32; a >>= 1)
    {
        if ( tid < a )
        {
            sPartials[tid + THREADS*0] += sPartials[tid + a + THREADS*0];
            sPartials[tid + THREADS*1] += sPartials[tid + a + THREADS*1];
        }
        __syncthreads();
    }
    if ( tid < 32 )
    {
        grad_gamma = sPartials[tid + THREADS*0] + sPartials[tid + 32 + THREADS*0];
        grad_beta  = sPartials[tid + THREADS*1] + sPartials[tid + 32 + THREADS*1];

        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
        {
            grad_gamma += __shfl_xor(grad_gamma, i);
            grad_beta  += __shfl_xor(grad_beta,  i);
        }
        sPartials[tid + THREADS*0] = grad_gamma;
        sPartials[tid + THREADS*1] = grad_beta;
    }
    __syncthreads();
    grad_gamma = sPartials[THREADS*0];
    grad_beta  = sPartials[THREADS*1];
"""
    else:
        reduction_src = r"""
    #pragma unroll
    for (int i = 16; i > 0; i >>= 1)
    {
        grad_gamma += __shfl_xor(grad_gamma, i);
        grad_beta  += __shfl_xor(grad_beta,  i);
    }
"""

    template = r"""
#define THREADS %(threads)s

%(common)s
%(binary)s

__global__ void batchnorm_bprop (
    %(type)s* delta_out, float* grad_gamma_out, float* grad_beta_out,
    const %(type)s* delta_in, const %(type)s* x_in, const float* xsum_in,
    const float* xvar_in, const float* gamma_in,
    const float eps, const int N, bool binary)
{
    %(share)s

    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;
    const float rcpN = 1.0f/(float)N;
    int offset = bid * N;

    const %(type)s* x_in0 = x_in     + offset + tid;
    const %(type)s* d_in0 = delta_in + offset + tid;

    float xmean = __ldg(xsum_in  + bid) * rcpN;
    float xvar  = __ldg(xvar_in  + bid);
    float gamma = __ldg(gamma_in + bid);

    float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps);
    float grad_gamma    = 0.0f;
    float grad_beta     = 0.0f;

    for (int i = tid; i < N; i += THREADS)
    {
        float x = %(cvt)s(__ldg(x_in0));
        x_in0 += THREADS;
        float d = %(cvt)s(__ldg(d_in0));
        d_in0 += THREADS;

        float xhat = 0.0f;
        if (binary) {
            xhat = shift_element(x - xmean, xvar_rcp_sqrt, true);
        } else {
            xhat = (x - xmean) * xvar_rcp_sqrt;
        }

        grad_gamma += xhat * d;
        grad_beta  += d;
    }
    %(red)s

    if ( tid == 0 )
    {
        *(grad_gamma_out + bid) = grad_gamma;
        *(grad_beta_out  + bid) = grad_beta;
    }

    int start = N - (THREADS*4 - tid);
    offset += start;
    const %(type)s* x_in1 = x_in     + offset;
    const %(type)s* d_in1 = delta_in + offset;
    delta_out += offset;

    for (int i = start; i >= -THREADS*3; i -= THREADS*4)
    {
        float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in1 + THREADS*0)) : 0.0f;
        float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in1 + THREADS*1)) : 0.0f;
        float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in1 + THREADS*2)) : 0.0f;
        float x3 =                   %(cvt)s(__ldg(x_in1 + THREADS*3));

        float d0 = i >= -THREADS*0 ? %(cvt)s(__ldg(d_in1 + THREADS*0)) : 0.0f;
        float d1 = i >= -THREADS*1 ? %(cvt)s(__ldg(d_in1 + THREADS*1)) : 0.0f;
        float d2 = i >= -THREADS*2 ? %(cvt)s(__ldg(d_in1 + THREADS*2)) : 0.0f;
        float d3 =                   %(cvt)s(__ldg(d_in1 + THREADS*3));

        x_in1 -= THREADS*4;
        d_in1 -= THREADS*4;

        float xhat0 = 0.0f;
        float xhat1 = 0.0f;
        float xhat2 = 0.0f;
        float xhat3 = 0.0f;

        float xtmp0 = 0.0f;
        float xtmp1 = 0.0f;
        float xtmp2 = 0.0f;
        float xtmp3 = 0.0f;

        float delta0 = 0.0f;
        float delta1 = 0.0f;
        float delta2 = 0.0f;
        float delta3 = 0.0f;

        if (binary) {
            xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true);
            xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true);
            xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true);
            xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true);

            xtmp0 = (shift_element(xhat0, grad_gamma, true) + grad_beta) * rcpN;
            xtmp1 = (shift_element(xhat1, grad_gamma, true) + grad_beta) * rcpN;
            xtmp2 = (shift_element(xhat2, grad_gamma, true) + grad_beta) * rcpN;
            xtmp3 = (shift_element(xhat3, grad_gamma, true) + grad_beta) * rcpN;

            delta0 = shift_element(shift_element(d0 - xtmp0, gamma, true), xvar_rcp_sqrt, true);
            delta1 = shift_element(shift_element(d1 - xtmp1, gamma, true), xvar_rcp_sqrt, true);
            delta2 = shift_element(shift_element(d2 - xtmp2, gamma, true), xvar_rcp_sqrt, true);
            delta3 = shift_element(shift_element(d3 - xtmp3, gamma, true), xvar_rcp_sqrt, true);
        } else {
            xhat0 = (x0 - xmean) * xvar_rcp_sqrt;
            xhat1 = (x1 - xmean) * xvar_rcp_sqrt;
            xhat2 = (x2 - xmean) * xvar_rcp_sqrt;
            xhat3 = (x3 - xmean) * xvar_rcp_sqrt;

            xtmp0 = (xhat0 * grad_gamma + grad_beta) * rcpN;
            xtmp1 = (xhat1 * grad_gamma + grad_beta) * rcpN;
            xtmp2 = (xhat2 * grad_gamma + grad_beta) * rcpN;
            xtmp3 = (xhat3 * grad_gamma + grad_beta) * rcpN;

            delta0 = gamma * (d0 - xtmp0) * xvar_rcp_sqrt;
            delta1 = gamma * (d1 - xtmp1) * xvar_rcp_sqrt;
            delta2 = gamma * (d2 - xtmp2) * xvar_rcp_sqrt;
            delta3 = gamma * (d3 - xtmp3) * xvar_rcp_sqrt;
        }

        %(delta0_out)s
        %(delta1_out)s
        %(delta2_out)s
        %(delta3_out)s
        if (i >= -THREADS*0) *(delta_out + THREADS*0) = delta0_val;
        if (i >= -THREADS*1) *(delta_out + THREADS*1) = delta1_val;
        if (i >= -THREADS*2) *(delta_out + THREADS*2) = delta2_val;
                             *(delta_out + THREADS*3) = delta3_val;
        delta_out -= THREADS*4;
    }
}
"""
    # Statement template that converts a float result into the output type;
    # types without a dedicated rounding snippet just copy the float.
    store_fmt = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};")

    # Helper source emitted ahead of the kernel definition.
    preamble = _common_round["nearest"].get(dtype, "")
    if dtype == "f2":
        preamble = preamble + _common_fp16_to_fp32

    cc_major = compute_capability[0]
    if cc_major < 3 or (cc_major == 3 and compute_capability[1] < 5):
        preamble = preamble + _common_kepler

    fields = {
        "common": preamble,
        "binary": shift_element(),
        "share": shared_decl,
        "red": reduction_src,
        "threads": threads,
    }
    type_info = _ew_types[dtype]
    fields["type"] = type_info["type"]
    fields["cvt"] = type_info["cvt"]
    # One conversion statement per unrolled output value (delta0..delta3).
    for idx in range(4):
        fields["delta%d_out" % idx] = store_fmt.format(
            "delta%d_val" % idx, "delta%d" % idx)

    module = SourceModule(template % fields, options=["--use_fast_math"])
    kernel = module.get_function("batchnorm_bprop")
    # Eight pointers, one float (eps), then N / binary.
    # NOTE(review): `binary` is declared `bool` in the kernel but packed as
    # "I" here -- confirm this matches the intended argument layout.
    kernel.prepare("PPPPPPPPfII")
    kernel.name = "batchnorm_bprop"
    return kernel
예제 #4
0
def _get_bn_bprop_kernel(dtype, threads, compute_capability):
    """Generate, compile and prepare the ``batchnorm_bprop`` CUDA kernel.

    The CUDA source is filled in from a template and compiled with
    ``SourceModule`` using ``--use_fast_math``.  In the generated kernel each
    block ``bid`` works on the ``N`` contiguous elements starting at
    ``bid * N``: a first pass accumulates ``grad_gamma`` (sum of
    ``xhat * delta``) and ``grad_beta`` (sum of ``delta``), reduces both
    across the block and stores them in ``grad_gamma_out`` /
    ``grad_beta_out``.  A second pass writes
    ``gamma * (delta - (xhat * grad_gamma + grad_beta) / N) / sqrt(xvar + eps)``
    to ``delta_out``.  When the kernel's ``binary`` flag is set the products
    are computed through the CUDA ``shift_element`` helper instead of plain
    multiplication.

    Args:
        dtype: Type key used to look up the element type / conversion in
            ``_ew_types`` and the rounding snippets in ``_ew_strings`` and
            ``_common_round``.  ``"f2"`` additionally pulls in
            ``_common_fp16_to_fp32``.
        threads: Value substituted for the ``THREADS`` macro.  Values above
            32 select the shared-memory reduction, otherwise only warp
            shuffles are used.
        compute_capability: Indexable ``(major, minor)`` pair; anything
            below 3.5 appends ``_common_kepler`` to the kernel preamble.

    Returns:
        The kernel function, prepared with the argument format
        ``"PPPPPPPPfII"`` and with ``name`` set to ``"batchnorm_bprop"``.
    """
    # Only blocks wider than 32 threads declare and use the shared buffer,
    # which holds the two partial sums side by side (hence THREADS * 2).
    # NOTE(review): that path reads sPartials[tid + 32] and halves THREADS,
    # so it presumably expects a power of two >= 64 -- confirm with callers.
    needs_shared = threads > 32
    shared_decl = "__shared__ float sPartials[THREADS * 2];" if needs_shared else ""
    if needs_shared:
        reduction_src = r"""
    sPartials[tid + THREADS*0] = grad_gamma;
    sPartials[tid + THREADS*1] = grad_beta;
    __syncthreads();

    #pragma unroll
    for (int a = THREADS >> 1; a > 32; a >>= 1)
    {
        if ( tid < a )
        {
            sPartials[tid + THREADS*0] += sPartials[tid + a + THREADS*0];
            sPartials[tid + THREADS*1] += sPartials[tid + a + THREADS*1];
        }
        __syncthreads();
    }
    if ( tid < 32 )
    {
        grad_gamma = sPartials[tid + THREADS*0] + sPartials[tid + 32 + THREADS*0];
        grad_beta  = sPartials[tid + THREADS*1] + sPartials[tid + 32 + THREADS*1];

        #pragma unroll
        for (int i = 16; i > 0; i >>= 1)
        {
            grad_gamma += __shfl_xor(grad_gamma, i);
            grad_beta  += __shfl_xor(grad_beta,  i);
        }
        sPartials[tid + THREADS*0] = grad_gamma;
        sPartials[tid + THREADS*1] = grad_beta;
    }
    __syncthreads();
    grad_gamma = sPartials[THREADS*0];
    grad_beta  = sPartials[THREADS*1];
"""
    else:
        reduction_src = r"""
    #pragma unroll
    for (int i = 16; i > 0; i >>= 1)
    {
        grad_gamma += __shfl_xor(grad_gamma, i);
        grad_beta  += __shfl_xor(grad_beta,  i);
    }
"""

    template = r"""
#define THREADS %(threads)s

%(common)s
%(binary)s

__global__ void batchnorm_bprop (
    %(type)s* delta_out, float* grad_gamma_out, float* grad_beta_out,
    const %(type)s* delta_in, const %(type)s* x_in, const float* xsum_in,
    const float* xvar_in, const float* gamma_in,
    const float eps, const int N, bool binary)
{
    %(share)s

    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;
    const float rcpN = 1.0f/(float)N;
    int offset = bid * N;

    const %(type)s* x_in0 = x_in     + offset + tid;
    const %(type)s* d_in0 = delta_in + offset + tid;

    float xmean = __ldg(xsum_in  + bid) * rcpN;
    float xvar  = __ldg(xvar_in  + bid);
    float gamma = __ldg(gamma_in + bid);

    float xvar_rcp_sqrt = 1.0f / sqrtf(xvar + eps);
    float grad_gamma    = 0.0f;
    float grad_beta     = 0.0f;

    for (int i = tid; i < N; i += THREADS)
    {
        float x = %(cvt)s(__ldg(x_in0));
        x_in0 += THREADS;
        float d = %(cvt)s(__ldg(d_in0));
        d_in0 += THREADS;

        float xhat = 0.0f;
        if (binary) {
            xhat = shift_element(x - xmean, xvar_rcp_sqrt, true);
        } else {
            xhat = (x - xmean) * xvar_rcp_sqrt;
        }

        grad_gamma += xhat * d;
        grad_beta  += d;
    }
    %(red)s

    if ( tid == 0 )
    {
        *(grad_gamma_out + bid) = grad_gamma;
        *(grad_beta_out  + bid) = grad_beta;
    }

    int start = N - (THREADS*4 - tid);
    offset += start;
    const %(type)s* x_in1 = x_in     + offset;
    const %(type)s* d_in1 = delta_in + offset;
    delta_out += offset;

    for (int i = start; i >= -THREADS*3; i -= THREADS*4)
    {
        float x0 = i >= -THREADS*0 ? %(cvt)s(__ldg(x_in1 + THREADS*0)) : 0.0f;
        float x1 = i >= -THREADS*1 ? %(cvt)s(__ldg(x_in1 + THREADS*1)) : 0.0f;
        float x2 = i >= -THREADS*2 ? %(cvt)s(__ldg(x_in1 + THREADS*2)) : 0.0f;
        float x3 =                   %(cvt)s(__ldg(x_in1 + THREADS*3));

        float d0 = i >= -THREADS*0 ? %(cvt)s(__ldg(d_in1 + THREADS*0)) : 0.0f;
        float d1 = i >= -THREADS*1 ? %(cvt)s(__ldg(d_in1 + THREADS*1)) : 0.0f;
        float d2 = i >= -THREADS*2 ? %(cvt)s(__ldg(d_in1 + THREADS*2)) : 0.0f;
        float d3 =                   %(cvt)s(__ldg(d_in1 + THREADS*3));

        x_in1 -= THREADS*4;
        d_in1 -= THREADS*4;

        float xhat0 = 0.0f;
        float xhat1 = 0.0f;
        float xhat2 = 0.0f;
        float xhat3 = 0.0f;

        float xtmp0 = 0.0f;
        float xtmp1 = 0.0f;
        float xtmp2 = 0.0f;
        float xtmp3 = 0.0f;

        float delta0 = 0.0f;
        float delta1 = 0.0f;
        float delta2 = 0.0f;
        float delta3 = 0.0f;

        if (binary) {
            xhat0 = shift_element(x0 - xmean, xvar_rcp_sqrt, true);
            xhat1 = shift_element(x1 - xmean, xvar_rcp_sqrt, true);
            xhat2 = shift_element(x2 - xmean, xvar_rcp_sqrt, true);
            xhat3 = shift_element(x3 - xmean, xvar_rcp_sqrt, true);

            xtmp0 = (shift_element(xhat0, grad_gamma, true) + grad_beta) * rcpN;
            xtmp1 = (shift_element(xhat1, grad_gamma, true) + grad_beta) * rcpN;
            xtmp2 = (shift_element(xhat2, grad_gamma, true) + grad_beta) * rcpN;
            xtmp3 = (shift_element(xhat3, grad_gamma, true) + grad_beta) * rcpN;

            delta0 = shift_element(shift_element(d0 - xtmp0, gamma, true), xvar_rcp_sqrt, true);
            delta1 = shift_element(shift_element(d1 - xtmp1, gamma, true), xvar_rcp_sqrt, true);
            delta2 = shift_element(shift_element(d2 - xtmp2, gamma, true), xvar_rcp_sqrt, true);
            delta3 = shift_element(shift_element(d3 - xtmp3, gamma, true), xvar_rcp_sqrt, true);
        } else {
            xhat0 = (x0 - xmean) * xvar_rcp_sqrt;
            xhat1 = (x1 - xmean) * xvar_rcp_sqrt;
            xhat2 = (x2 - xmean) * xvar_rcp_sqrt;
            xhat3 = (x3 - xmean) * xvar_rcp_sqrt;

            xtmp0 = (xhat0 * grad_gamma + grad_beta) * rcpN;
            xtmp1 = (xhat1 * grad_gamma + grad_beta) * rcpN;
            xtmp2 = (xhat2 * grad_gamma + grad_beta) * rcpN;
            xtmp3 = (xhat3 * grad_gamma + grad_beta) * rcpN;

            delta0 = gamma * (d0 - xtmp0) * xvar_rcp_sqrt;
            delta1 = gamma * (d1 - xtmp1) * xvar_rcp_sqrt;
            delta2 = gamma * (d2 - xtmp2) * xvar_rcp_sqrt;
            delta3 = gamma * (d3 - xtmp3) * xvar_rcp_sqrt;
        }

        %(delta0_out)s
        %(delta1_out)s
        %(delta2_out)s
        %(delta3_out)s
        if (i >= -THREADS*0) *(delta_out + THREADS*0) = delta0_val;
        if (i >= -THREADS*1) *(delta_out + THREADS*1) = delta1_val;
        if (i >= -THREADS*2) *(delta_out + THREADS*2) = delta2_val;
                             *(delta_out + THREADS*3) = delta3_val;
        delta_out -= THREADS*4;
    }
}
"""
    # Statement template that converts a float result into the output type;
    # types without a dedicated rounding snippet just copy the float.
    store_fmt = _ew_strings["round"]["nearest"].get(dtype, "float {0} = {1};")

    # Helper source emitted ahead of the kernel definition.
    preamble = _common_round["nearest"].get(dtype, "")
    if dtype == "f2":
        preamble = preamble + _common_fp16_to_fp32

    cc_major = compute_capability[0]
    if cc_major < 3 or (cc_major == 3 and compute_capability[1] < 5):
        preamble = preamble + _common_kepler

    fields = {
        "common": preamble,
        "binary": shift_element(),
        "share": shared_decl,
        "red": reduction_src,
        "threads": threads,
    }
    type_info = _ew_types[dtype]
    fields["type"] = type_info["type"]
    fields["cvt"] = type_info["cvt"]
    # One conversion statement per unrolled output value (delta0..delta3).
    for idx in range(4):
        fields["delta%d_out" % idx] = store_fmt.format(
            "delta%d_val" % idx, "delta%d" % idx)

    module = SourceModule(template % fields, options=["--use_fast_math"])
    kernel = module.get_function("batchnorm_bprop")
    # Eight pointers, one float (eps), then N / binary.
    # NOTE(review): `binary` is declared `bool` in the kernel but packed as
    # "I" here -- confirm this matches the intended argument layout.
    kernel.prepare("PPPPPPPPfII")
    kernel.name = "batchnorm_bprop"
    return kernel