def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd = self.saved_tensors
    grad_input = grad_weight = grad_bias = None
    process_group = self.process_group
    world_size = self.world_size
    rank = self.rank

    # calculate local stats as well as grad_weight / grad_bias
    mean_dy, mean_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        self.needs_input_grad[0],
        self.needs_input_grad[1],
        self.needs_input_grad[2])

    if self.needs_input_grad[0]:
        # no need to communicate with others
        # backward pass for gradient calculation
        grad_input = torch.batch_norm_backward_elemt(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            mean_dy,
            mean_dy_xmu)

    # synchronizing of grad_weight / grad_bias is not needed as distributed
    # training would handle all reduce.
    if weight is None or not self.needs_input_grad[1]:
        grad_weight = None

    if weight is None or not self.needs_input_grad[2]:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
    grad_input = grad_weight = grad_bias = None
    process_group = self.process_group

    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        self.needs_input_grad[0],
        self.needs_input_grad[1],
        self.needs_input_grad[2]
    )

    if self.needs_input_grad[0]:
        # synchronizing stats used to calculate input gradient.
        # TODO: move div_ into batch_norm_backward_elemt kernel
        sum_dy_all_reduce = torch.distributed.all_reduce(
            sum_dy, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
        sum_dy_xmu_all_reduce = torch.distributed.all_reduce(
            sum_dy_xmu, torch.distributed.ReduceOp.SUM, process_group, async_op=True)

        # wait on the async communication to finish
        sum_dy_all_reduce.wait()
        sum_dy_xmu_all_reduce.wait()

        divisor = count_tensor.sum()
        mean_dy = sum_dy / divisor
        mean_dy_xmu = sum_dy_xmu / divisor
        # backward pass for gradient calculation
        grad_input = torch.batch_norm_backward_elemt(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            mean_dy,
            mean_dy_xmu
        )

    # synchronizing of grad_weight / grad_bias is not needed as distributed
    # training would handle all reduce.
    if weight is None or not self.needs_input_grad[1]:
        grad_weight = None

    if weight is None or not self.needs_input_grad[2]:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
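# A minimal, self-contained sketch (not from the source) of the async
# all_reduce + wait pattern used above: both reductions are launched before
# either is awaited, so the two communications can overlap. It spins up a
# single-process gloo group, so the reduction itself is numerically a no-op;
# the addresses and port below are arbitrary assumptions for local testing.
import os
import torch
import torch.distributed as dist

def _async_allreduce_demo():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    sum_dy = torch.ones(4)
    sum_dy_xmu = torch.full((4,), 2.0)
    # launch both reductions first ...
    h1 = dist.all_reduce(sum_dy, dist.ReduceOp.SUM, async_op=True)
    h2 = dist.all_reduce(sum_dy_xmu, dist.ReduceOp.SUM, async_op=True)
    # ... then wait on both handles
    h1.wait()
    h2.wait()

    dist.destroy_process_group()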
def backward(self, grad_output):
    if not grad_output.is_contiguous(memory_format=torch.channels_last):
        grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
    grad_input = grad_weight = grad_bias = None
    process_group = self.process_group

    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        self.needs_input_grad[0],
        self.needs_input_grad[1],
        self.needs_input_grad[2]
    )

    if self.needs_input_grad[0]:
        # synchronizing stats used to calculate input gradient.
        num_channels = sum_dy.shape[0]
        combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
        # all_reduce: sum the gradient stats across processes
        torch.distributed.all_reduce(
            combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
        # average the gradients by the total count (done inside the kernel
        # below via count_tensor)
        sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

        # backward pass for gradient calculation
        grad_input = torch.batch_norm_backward_elemt(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            sum_dy,
            sum_dy_xmu,
            count_tensor
        )

    # synchronizing of grad_weight / grad_bias is not needed as distributed
    # training would handle all reduce.
    if weight is None or not self.needs_input_grad[1]:
        grad_weight = None

    if weight is None or not self.needs_input_grad[2]:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
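# A small numeric check (plain tensors, no process group; not from the source)
# of the coalescing trick above: concatenating sum_dy and sum_dy_xmu lets one
# all_reduce replace the two separate reductions of the previous version. The
# reduction is simulated here by summing per-rank tensors.
import torch

def _coalesced_reduce_demo():
    num_channels = 4
    per_rank = [(torch.randn(num_channels), torch.randn(num_channels))
                for _ in range(3)]

    # reference: two separate "reductions"
    sum_dy_ref = torch.stack([dy for dy, _ in per_rank]).sum(0)
    sum_dy_xmu_ref = torch.stack([dyx for _, dyx in per_rank]).sum(0)

    # coalesced: one "reduction" over the concatenated buffer, then split
    combined = torch.stack([torch.cat([dy, dyx], dim=0)
                            for dy, dyx in per_rank]).sum(0)
    sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

    assert torch.allclose(sum_dy, sum_dy_ref)
    assert torch.allclose(sum_dy_xmu, sum_dy_xmu_ref)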
def backward(ctx, grad_output):
    assert ctx.training, 'Inplace BatchNorm supports only training mode, input gradients with running stats used are not yet supported'
    saved_output, weight, bias, mean, invstd = ctx.saved_tensors
    # recover the input by inverting the normalization in place:
    # batch_norm_elemt computes (x - mean) * invstd * weight + bias, so the
    # swapped arguments below evaluate (out - bias) / (weight * invstd) + mean,
    # writing the recovered input back into the saved output buffer
    saved_input = torch.batch_norm_elemt(
        saved_output,
        invstd.reciprocal(),
        mean,
        bias,
        weight.reciprocal(),
        0,
        out=saved_output
    )
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output, saved_input, mean, invstd, weight, *ctx.needs_input_grad[:3])
    # elements per channel across the batch and spatial dims
    divisor = saved_input.numel() // saved_input.size(1)
    mean_dy = sum_dy.div_(divisor)
    mean_dy_xmu = sum_dy_xmu.div_(divisor)
    grad_input = torch.batch_norm_backward_elemt(
        grad_output, saved_input, mean, invstd, weight, mean_dy, mean_dy_xmu
    )
    return grad_input, grad_weight, grad_bias, None, None, None, None, None
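# A quick numeric check (plain tensor math, hypothetical shapes; not from the
# source) of the inversion used above. With out = (x - mean) * invstd * weight + bias,
# rearranging gives x = (out - bias) / (weight * invstd) + mean, which is what
# the swapped batch_norm_elemt call computes (eps is 0 because eps is already
# folded into invstd).
import torch

def _inversion_demo():
    x = torch.randn(8, 3)
    mean = x.mean(0)
    invstd = x.var(0, unbiased=False).add(1e-5).rsqrt()
    weight = torch.rand(3) + 0.5
    bias = torch.randn(3)

    out = (x - mean) * invstd * weight + bias
    x_rec = (out - bias) * weight.reciprocal() * invstd.reciprocal() + mean
    assert torch.allclose(x, x_rec, atol=1e-5)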
def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_all = self.saved_tensors
    need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[0:3]

    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        need_input_grad,
        need_weight_grad,
        need_bias_grad)

    if need_input_grad:
        # synchronizing stats used to calculate input gradient.
        sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
        sum_dy_xmu_handle = allreduce_async(
            sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

        # wait on the async communication to finish
        sum_dy = synchronize(sum_dy_handle)
        sum_dy_xmu = synchronize(sum_dy_xmu_handle)

        if _SYNC_BN_V2 or _SYNC_BN_V3:
            count_all_sum = count_all.sum()
            mean_dy = sum_dy / count_all_sum
            mean_dy_xmu = sum_dy_xmu / count_all_sum
        else:
            # before 1.5.0, sum_dy was sum of means from every worker, so we just
            # need to divide it by number of workers
            mean_dy = sum_dy / size()
            mean_dy_xmu = sum_dy_xmu / size()

        # backward pass for gradient calculation
        grad_input = torch.batch_norm_backward_elemt(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            mean_dy,
            mean_dy_xmu)
    else:
        grad_input = None

    # synchronizing of grad_weight / grad_bias is not needed as distributed
    # training would handle all reduce.
    if weight is None or not need_weight_grad:
        grad_weight = None

    if weight is None or not need_bias_grad:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, bias, count_tensor = self.saved_tensors

    # av: re-compute batch normalized out
    eps = 1e-5
    out = torch.batch_norm_elemt(saved_input, weight, bias, mean, invstd, eps)
    sigmoid_out = torch.sigmoid(out)
    # chain rule through swish: d(x * sigmoid(x))/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    grad_output *= (sigmoid_out * (1 + out * (1 - sigmoid_out)))
    # av: end

    grad_input = grad_weight = grad_bias = None
    process_group = self.process_group

    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        self.needs_input_grad[0],
        self.needs_input_grad[1],
        self.needs_input_grad[2])

    if self.needs_input_grad[0]:
        # synchronizing stats used to calculate input gradient.
        # TODO: move div_ into batch_norm_backward_elemt kernel
        num_channels = sum_dy.shape[0]
        combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
        torch.distributed.all_reduce(
            combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
        sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

        divisor = count_tensor.sum()
        mean_dy = sum_dy / divisor
        mean_dy_xmu = sum_dy_xmu / divisor
        # backward pass for gradient calculation
        grad_input = torch.batch_norm_backward_elemt(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            mean_dy,
            mean_dy_xmu)

    # synchronizing of grad_weight / grad_bias is not needed as distributed
    # training would handle all reduce.
    if weight is None or not self.needs_input_grad[1]:
        grad_weight = None

    if weight is None or not self.needs_input_grad[2]:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
def backward(self, grad_output):
    if not grad_output.is_contiguous(memory_format=torch.channels_last):
        grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
    grad_input = grad_weight = grad_bias = None
    process_group = self.process_group

    if saved_input.numel() > 0:
        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            self.needs_input_grad[0],
            self.needs_input_grad[1],
            self.needs_input_grad[2]
        )

        if self.needs_input_grad[0]:
            # synchronizing stats used to calculate input gradient.
            num_channels = sum_dy.shape[0]
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            torch.distributed.all_reduce(
                combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
            sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy,
                sum_dy_xmu,
                count_tensor
            )

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None
    else:
        # This process got an empty input tensor in the forward pass.
        # Although this process can directly set grad_input as an empty
        # tensor of zeros, it still needs to participate in the collective
        # communication to unblock its peers, as other peer processes might
        # have received non-empty inputs.
        num_channels = saved_input.shape[1]
        if self.needs_input_grad[0]:
            # launch all_reduce to unblock other peer processes
            combined = torch.zeros(
                2 * num_channels,
                dtype=saved_input.dtype,
                device=saved_input.device
            )
            torch.distributed.all_reduce(
                combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)

        # Leave grad_input, grad_weight and grad_bias as None, which will be
        # interpreted by the autograd engine as Tensors full of zeros.

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_all = self.saved_tensors
    need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[0:3]

    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        need_input_grad,
        need_weight_grad,
        need_bias_grad)

    if need_input_grad:
        # synchronizing stats used to calculate input gradient.
        sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
        sum_dy_xmu_handle = allreduce_async(
            sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

        # wait on the async communication to finish
        sum_dy = synchronize(sum_dy_handle)
        sum_dy_xmu = synchronize(sum_dy_xmu_handle)

        if _SYNC_BN_V4:
            # from 1.9.0 on we need a count tensor on all devices
            # count_all is calculated as total count across all ranks in forward function
            count_all = count_all.to(dtype=torch.int, device=grad_output.device)
        elif _SYNC_BN_V2 or _SYNC_BN_V3:
            # before 1.9.0 we need the count as an integer to compute mean values
            count = count_all.sum()
        else:
            # before 1.5.0, sum_dy was sum of means from every worker, so we just
            # need to divide it by number of workers
            count = size()

        # backward pass for gradient calculation
        # we are calling into a non-public undocumented function which broke moving to 1.9.0
        # https://github.com/pytorch/pytorch/issues/57900
        if _SYNC_BN_V4:
            # from 1.9.0 on, sums and count parameters expected
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy,
                sum_dy_xmu,
                count_all)
        else:
            # before 1.9.0, mean parameters expected, not sums and count
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy / count,
                sum_dy_xmu / count)
    else:
        grad_input = None

    # synchronizing of grad_weight / grad_bias is not needed as distributed
    # training would handle all reduce.
    if weight is None or not need_weight_grad:
        grad_weight = None

    if weight is None or not need_bias_grad:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
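# Typical user-facing entry point that routes through a backward like those
# above (a sketch, not from the source; assumes torch.distributed is already
# initialized, e.g. via torchrun, and one CUDA device per process):
import torch
import torch.nn as nn

def build_ddp_model():
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
    # swap every BatchNorm*d layer for SyncBatchNorm so stats sync across ranks
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(model.cuda())
    return model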