def backward(ctx, grad_output):
    grad_output = grad_output.contiguous()
    torch.cuda.nvtx.range_push("sync_BN_bw")
    # mini-batch mean & var were calculated by the forward path:
    # mu = 1./N*np.sum(h, axis=0)
    # var = 1./N*np.sum((h-mu)**2, axis=0)
    saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
    process_group = ctx.process_group
    channel_last = ctx.channel_last
    world_size = ctx.world_size
    fuse_relu = ctx.fuse_relu

    grad_input = grad_z = grad_weight = grad_bias = None

    if fuse_relu:
        grad_output = syncbn.relu_bw_c_last(
            grad_output, saved_input, z, mean, inv_std, weight, bias)
    if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
        grad_z = grad_output.clone()  # TODO(jie): why do I have to clone here? life time of grad_output?

    # Per-channel reductions over the local batch: mean of dy, mean of
    # dy * (x - mu), plus the local weight/bias gradients.
    if channel_last:
        mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(
            grad_output, saved_input, mean, inv_std, weight)
    else:
        mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(
            grad_output, saved_input, mean, inv_std, weight)

    # calculate grad_input
    if ctx.needs_input_grad[0]:
        if torch.distributed.is_initialized():
            # Average the per-channel statistics across ranks: two separate
            # all_reduce calls, each followed by a division by world_size.
            torch.distributed.all_reduce(
                mean_dy, torch.distributed.ReduceOp.SUM, process_group)
            mean_dy = mean_dy / world_size
            torch.distributed.all_reduce(
                mean_dy_xmu, torch.distributed.ReduceOp.SUM, process_group)
            mean_dy_xmu = mean_dy_xmu / world_size
        if channel_last:
            grad_input = syncbn.batchnorm_backward_c_last(
                grad_output, saved_input, mean, inv_std, weight,
                mean_dy, mean_dy_xmu)
        else:
            grad_input = syncbn.batchnorm_backward(
                grad_output, saved_input, mean, inv_std, weight,
                mean_dy, mean_dy_xmu)

    if weight is None or not ctx.needs_input_grad[2]:
        grad_weight = None
    if weight is None or not ctx.needs_input_grad[3]:
        grad_bias = None

    torch.cuda.nvtx.range_pop()
    return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
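# For reference: a rough pure-PyTorch sketch of the per-channel math the fused
# `syncbn.batchnorm_backward` kernel above implements, assuming NCHW layout.
# The function name and broadcasting helper are hypothetical, not part of the
# syncbn extension. With xhat = (x - mu) * inv_std, the standard batch-norm
# input gradient is
#     dx = weight * inv_std * (dy - mean_dy - xhat * inv_std * mean_dy_xmu)
def batchnorm_backward_reference(grad_output, saved_input, mean, inv_std,
                                 weight, mean_dy, mean_dy_xmu):
    def c(t):
        # Broadcast a per-channel [C] tensor over an (N, C, H, W) input.
        return t.reshape(1, -1, 1, 1)

    xmu = saved_input - c(mean)
    grad_input = (grad_output - c(mean_dy)
                  - xmu * c(inv_std) ** 2 * c(mean_dy_xmu)) * c(inv_std)
    if weight is not None:
        grad_input = grad_input * c(weight)
    return grad_input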
def backward(ctx, grad_output):
    grad_output = grad_output.contiguous()
    # mini-batch mean & var were calculated by the forward path:
    # mu = 1./N*np.sum(h, axis=0)
    # var = 1./N*np.sum((h-mu)**2, axis=0)
    saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors
    process_group = ctx.process_group
    channel_last = ctx.channel_last
    world_size = ctx.world_size
    fuse_relu = ctx.fuse_relu

    grad_input = grad_z = grad_weight = grad_bias = None

    if fuse_relu:
        grad_output = syncbn.relu_bw_c_last(
            grad_output, saved_input, z, mean, inv_std, weight, bias)
    if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
        grad_z = grad_output.clone()

    # TODO: update kernel to not pre_divide by item_num
    if channel_last:
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(
            grad_output, saved_input, mean, inv_std, weight)
    else:
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(
            grad_output, saved_input, mean, inv_std, weight)

    # calculate grad_input
    if ctx.needs_input_grad[0]:
        if torch.distributed.is_initialized():
            num_channels = sum_dy.shape[0]
            # Concatenate the two per-channel sums so a single all_reduce
            # replaces two separate collective calls, then split them back.
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            torch.distributed.all_reduce(
                combined, torch.distributed.ReduceOp.SUM, process_group,
                async_op=False)
            sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
        if channel_last:
            grad_input = syncbn.batchnorm_backward_c_last(
                grad_output, saved_input, mean, inv_std, weight,
                sum_dy, sum_dy_xmu, count)
        else:
            grad_input = syncbn.batchnorm_backward(
                grad_output, saved_input, mean, inv_std, weight,
                sum_dy, sum_dy_xmu, count)

    if weight is None or not ctx.needs_input_grad[2]:
        grad_weight = None
    if weight is None or not ctx.needs_input_grad[3]:
        grad_bias = None

    return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
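# This variant differs from the one above in two ways: it reduces raw
# per-channel sums plus an element count (so uneven per-rank batch sizes stay
# correct) instead of pre-averaged means, and it fuses the two reductions into
# a single all_reduce. A minimal, self-contained sketch of that fusion pattern
# follows; `fused_channel_reduce` is an illustrative name, and it assumes
# torch.distributed has already been initialized (e.g. under torchrun).
import torch
import torch.distributed as dist

def fused_channel_reduce(sum_dy, sum_dy_xmu, process_group=None):
    num_channels = sum_dy.shape[0]
    # One collective over the concatenated buffer instead of two round trips.
    combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
    dist.all_reduce(combined, dist.ReduceOp.SUM, process_group)
    # Recover the two per-channel sums, now summed across all ranks.
    return torch.split(combined, num_channels)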