    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        torch.cuda.nvtx.range_push("sync_BN_bw")
        # mini-batch mean & var were calculated in the forward pass.
        # mu = 1./N*np.sum(h, axis = 0)
        # var = 1./N*np.sum((h-mu)**2, axis = 0)
        saved_input, weight, running_mean, running_variance = ctx.saved_tensors
        eps = ctx.eps
        process_group = ctx.process_group
        world_size = ctx.world_size
        grad_input = grad_weight = grad_bias = None

        # TODO(jie): why do I have to clone here? life time of grad_output?
        mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, running_mean, running_variance, weight, eps)

        # calculate grad_input
        if ctx.needs_input_grad[0]:

            if torch.distributed.is_initialized():
                torch.distributed.all_reduce(
                    mean_dy, torch.distributed.ReduceOp.SUM, process_group)
                mean_dy = mean_dy / world_size
                torch.distributed.all_reduce(
                    mean_dy_xmu, torch.distributed.ReduceOp.SUM, process_group)
                mean_dy_xmu = mean_dy_xmu / world_size
            grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)

        if weight is None or not ctx.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not ctx.needs_input_grad[2]:
            grad_bias = None

        torch.cuda.nvtx.range_pop()
        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
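
The heavy lifting above happens inside the syncbn CUDA extension. As a rough guide to what those calls compute, here is a minimal pure-PyTorch sketch of the same per-channel reductions and the closed-form input gradient for an NCHW tensor (hypothetical helper, not the actual kernels; it takes a saved mean/inv_std pair rather than the running statistics passed in this older variant):

import torch

def reference_bn_backward(grad_output, saved_input, mean, inv_std, weight):
    # Pure-PyTorch sketch of syncbn.reduce_bn + syncbn.batchnorm_backward
    # for an NCHW tensor; illustrative only, not the CUDA kernels above.
    C = saved_input.shape[1]
    dy = grad_output.transpose(0, 1).reshape(C, -1)        # (C, N*H*W)
    x = saved_input.transpose(0, 1).reshape(C, -1)
    xmu = x - mean.view(C, 1)

    mean_dy = dy.mean(dim=1)                                # per-channel mean of dy
    mean_dy_xmu = (dy * xmu).mean(dim=1)                    # per-channel mean of dy*(x-mu)
    grad_weight = (dy * xmu * inv_std.view(C, 1)).sum(dim=1)
    grad_bias = dy.sum(dim=1)

    # dx = gamma * inv_std * (dy - mean(dy) - (x-mu) * inv_std**2 * mean(dy*(x-mu)))
    grad_input = (dy - mean_dy.view(C, 1)
                  - xmu * inv_std.view(C, 1) ** 2 * mean_dy_xmu.view(C, 1))
    grad_input = grad_input * inv_std.view(C, 1) * weight.view(C, 1)
    grad_input = grad_input.reshape(saved_input.transpose(0, 1).shape).transpose(0, 1)
    return grad_input, mean_dy, mean_dy_xmu, grad_weight, grad_bias
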
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        torch.cuda.nvtx.range_push("sync_BN_bw")
        # mini-batch mean & var were calculated in the forward pass.
        # mu = 1./N*np.sum(h, axis = 0)
        # var = 1./N*np.sum((h-mu)**2, axis = 0)
        saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
        process_group = ctx.process_group
        channel_last = ctx.channel_last
        world_size = ctx.world_size
        fuse_relu = ctx.fuse_relu
        grad_input = grad_z = grad_weight = grad_bias = None

        if fuse_relu:
            grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z,
                                                mean, inv_std, weight, bias)
        if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
            grad_z = grad_output.clone()

        # TODO(jie): why do I have to clone here? life time of grad_output?
        if channel_last:
            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(
                grad_output, saved_input, mean, inv_std, weight)
        else:
            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(
                grad_output, saved_input, mean, inv_std, weight)

        # calculate grad_input
        if ctx.needs_input_grad[0]:

            if torch.distributed.is_initialized():
                torch.distributed.all_reduce(mean_dy, ReduceOp.SUM,
                                             process_group)
                mean_dy = mean_dy / world_size
                torch.distributed.all_reduce(mean_dy_xmu, ReduceOp.SUM,
                                             process_group)
                mean_dy_xmu = mean_dy_xmu / world_size
            if channel_last:
                grad_input = syncbn.batchnorm_backward_c_last(
                    grad_output, saved_input, mean, inv_std, weight, mean_dy,
                    mean_dy_xmu)
            else:
                grad_input = syncbn.batchnorm_backward(grad_output,
                                                       saved_input, mean,
                                                       inv_std, weight,
                                                       mean_dy, mean_dy_xmu)

        if weight is None or not ctx.needs_input_grad[2]:
            grad_weight = None

        if weight is None or not ctx.needs_input_grad[3]:
            grad_bias = None

        torch.cuda.nvtx.range_pop()
        return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
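
This newer variant saves the batch mean and inverse standard deviation from the forward pass instead of the running statistics, adds a channels-last code path, and can fuse the ReLU backward. Conceptually, the fused-ReLU step just masks the incoming gradient where the forward ReLU output was zero; a minimal sketch follows (hypothetical signature; the real syncbn.relu_bw_c_last kernel recomputes the normalized output from saved_input, mean, inv_std, weight, bias and z):

import torch

def relu_backward_sketch(grad_output, relu_input):
    # relu_input stands for the value fed to the forward ReLU
    # (the batch-norm output, plus z if a residual branch was fused in).
    return grad_output * (relu_input > 0).to(grad_output.dtype)
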
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        # mini-batch mean & var were calculated in the forward pass.
        # mu = 1./N*np.sum(h, axis = 0)
        # var = 1./N*np.sum((h-mu)**2, axis = 0)
        saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors
        process_group = ctx.process_group
        channel_last = ctx.channel_last
        world_size = ctx.world_size
        fuse_relu = ctx.fuse_relu
        grad_input = grad_z = grad_weight = grad_bias = None

        if fuse_relu:
            grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z,
                                                mean, inv_std, weight, bias)
        if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
            grad_z = grad_output.clone()

        # TODO: update kernel to not pre_divide by item_num
        if channel_last:
            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(
                grad_output, saved_input, mean, inv_std, weight)
        else:
            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(
                grad_output, saved_input, mean, inv_std, weight)

        # calculate grad_input
        if ctx.needs_input_grad[0]:

            if torch.distributed.is_initialized():
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(combined,
                                             torch.distributed.ReduceOp.SUM,
                                             process_group,
                                             async_op=False)
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

            if channel_last:
                grad_input = syncbn.batchnorm_backward_c_last(
                    grad_output, saved_input, mean, inv_std, weight, sum_dy,
                    sum_dy_xmu, count)
            else:
                grad_input = syncbn.batchnorm_backward(grad_output,
                                                       saved_input, mean,
                                                       inv_std, weight, sum_dy,
                                                       sum_dy_xmu, count)

        if weight is None or not ctx.needs_input_grad[2]:
            grad_weight = None

        if weight is None or not ctx.needs_input_grad[3]:
            grad_bias = None

        return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
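
The third variant all-reduces the raw per-channel sums rather than pre-divided means, packing both tensors into a single collective to save one all_reduce call, and lets the backward kernel divide by the global element count. Below is an unpacked sketch of the same reduction, assuming `count` already holds the element counts gathered from every rank in the forward pass:

import torch
import torch.distributed as dist

def reduce_sums_sketch(sum_dy, sum_dy_xmu, count, process_group=None):
    # Equivalent, unpacked form of the concatenated all_reduce above:
    # sum the raw per-channel reductions across ranks, then divide by the
    # total element count (the kernel above does this division internally).
    if dist.is_initialized():
        dist.all_reduce(sum_dy, dist.ReduceOp.SUM, process_group)
        dist.all_reduce(sum_dy_xmu, dist.ReduceOp.SUM, process_group)
    total = count.sum()
    return sum_dy / total, sum_dy_xmu / total
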
Example #4
    compare("comparing bn output: ", out_bn, out_r, error)

grad_output_t = type_tensor(grad)

grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
grad_output2_r = ref_tensor(grad)

grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) *
                 torch.rsqrt(b_v.view(-1, 1, 1) + eps) *
                 grad_output2_r).transpose(1, 0).contiguous().view(
                     feature_size, -1).sum(1)

mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(
    1, 0).contiguous().view(feature_size, -1).mean(1)

grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) -
                (inp2_r - m.view(-1, 1, 1)) /
                (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)
                ) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * weight_r.view(
                    -1, 1, 1)

mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
if args.local_rank == 0:
    sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
    sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
    sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
    sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
    sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
    compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)

if args.local_rank == 0:
    sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.module.running_mean.data, error) and sbn_result
    sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.module.running_var.data, error) and sbn_result

# executed by both ranks (not just local_rank == 0)
sbn_result = compare("comparing layers output: ", out_bn[start:finish], out_sbn, error) and sbn_result
Example #5
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) *
                 torch.rsqrt(b_v.view(-1, 1, 1) + eps) *
                 grad_output2_r).transpose(1, 0).contiguous().view(
                     feature_size, -1).sum(1)

mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(
    1, 0).contiguous().view(feature_size, -1).mean(1)

grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) -
                (inp2_r - m.view(-1, 1, 1)) /
                (b_v.view(-1, 1, 1) + eps) * mean_dy_xmu_r.view(-1, 1, 1)
                ) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) * weight_r.view(
                    -1, 1, 1)

mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(
    grad_output_t, inp_t, mean, var_biased, weight_t, eps)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased,
                                       weight_t, mean_dy, mean_dy_xmu, eps)
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r,
                     error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r,
                     error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r,
                     error) and sbn_result
sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu,
                     mean_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing input grad: ", grad_input, grad_input_r,
                     error) and sbn_result
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r,
                     error) and sbn_result
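
As a quick cross-check (not part of the original test), the closed-form `grad_input_r` expression above agrees with what autograd produces for a plain batch-norm expression on a small random NCHW tensor:

import torch

torch.manual_seed(0)
N, C, H, W, eps = 4, 3, 5, 5, 1e-5
x = torch.randn(N, C, H, W, dtype=torch.double, requires_grad=True)
gamma = torch.randn(C, dtype=torch.double)
dy = torch.randn(N, C, H, W, dtype=torch.double)

mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
y = (x - mu) * torch.rsqrt(var + eps) * gamma.view(1, C, 1, 1)
y.backward(dy)

# Closed form, mirroring grad_input_r above (biased variance, eps inside).
xmu = (x - mu).detach()
var_d = var.detach()
mean_dy = dy.mean(dim=(0, 2, 3), keepdim=True)
mean_dy_xmu = (dy * xmu).mean(dim=(0, 2, 3), keepdim=True)
grad_closed = (dy - mean_dy - xmu / (var_d + eps) * mean_dy_xmu) \
              * torch.rsqrt(var_d + eps) * gamma.view(1, C, 1, 1)
print(torch.allclose(x.grad, grad_closed))   # expected: True
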