Example #1
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group
        world_size = self.world_size
        rank = self.rank

        # calculate local stats as well as grad_weight / grad_bias
        mean_dy, mean_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output, saved_input, mean, invstd, weight,
            self.needs_input_grad[0], self.needs_input_grad[1],
            self.needs_input_grad[2])

        if self.needs_input_grad[0]:
            # no need to communicate with others
            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output, saved_input, mean, invstd, weight, mean_dy,
                mean_dy_xmu)

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
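
The two undocumented kernels used above reduce to plain tensor algebra. Below is a minimal single-process sketch of the same math, assuming the standard batch-norm backward formulas; the function name and structure are illustrative, not taken from the example.

import torch

def ref_batch_norm_backward(grad_output, x, mean, invstd, weight):
    # x, grad_output: (N, C, ...); mean, invstd, weight: (C,)
    reduce_dims = [0] + list(range(2, x.dim()))       # every dim except C
    shape = [1, -1] + [1] * (x.dim() - 2)             # broadcast (C,) over x
    xmu = x - mean.reshape(shape)

    # what batch_norm_backward_reduce computes per channel
    sum_dy = grad_output.sum(reduce_dims)             # == grad_bias
    sum_dy_xmu = (grad_output * xmu).sum(reduce_dims)
    grad_weight = sum_dy_xmu * invstd

    # what batch_norm_backward_elemt computes from the averaged stats
    n = x.numel() // x.size(1)
    mean_dy = sum_dy / n
    mean_dy_xmu = sum_dy_xmu / n
    grad_input = (weight * invstd).reshape(shape) * (
        grad_output
        - mean_dy.reshape(shape)
        - xmu * (invstd ** 2 * mean_dy_xmu).reshape(shape))
    return grad_input, grad_weight, sum_dy

The synchronized variants in the examples below differ only in where the averaging happens: sum_dy and sum_dy_xmu are all-reduced across processes first, and the divisor becomes the global element count instead of the local n.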
Example #2
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            self.needs_input_grad[0],
            self.needs_input_grad[1],
            self.needs_input_grad[2]
        )

        if self.needs_input_grad[0]:
            # synchronizing stats used to calculate input gradient.
            # TODO: move div_ into batch_norm_backward_elemt kernel
            sum_dy_all_reduce = torch.distributed.all_reduce(
                sum_dy, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
            sum_dy_xmu_all_reduce = torch.distributed.all_reduce(
                sum_dy_xmu, torch.distributed.ReduceOp.SUM, process_group, async_op=True)

            # wait on the async communication to finish
            sum_dy_all_reduce.wait()
            sum_dy_xmu_all_reduce.wait()

            divisor = count_tensor.sum()
            mean_dy = sum_dy / divisor
            mean_dy_xmu = sum_dy_xmu / divisor
            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                mean_dy,
                mean_dy_xmu
            )

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
Example #3
    def backward(self, grad_output):
        if not grad_output.is_contiguous(memory_format=torch.channels_last):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            self.needs_input_grad[0],
            self.needs_input_grad[1],
            self.needs_input_grad[2]
        )

        if self.needs_input_grad[0]:
            # synchronizing stats used to calculate input gradient.
            num_channels = sum_dy.shape[0]
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            # all_reduce: sum the per-channel gradient statistics across processes
            torch.distributed.all_reduce(
                combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)

            # split back into sum_dy / sum_dy_xmu; averaging by the total count
            # happens inside batch_norm_backward_elemt via count_tensor
            sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy,
                sum_dy_xmu,
                count_tensor
            )

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
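
count_tensor above holds one entry per process (the number of elements each rank normalized over), so its sum is the global divisor. The examples do not show the forward pass; the following is a hedged sketch of how such a per-rank count tensor is typically gathered, assuming an already-initialized torch.distributed process group (the helper name is illustrative).

import torch
import torch.distributed as dist

def gather_counts(input, process_group=None):
    # local count = number of elements per channel on this rank
    count = torch.full((1,), input.numel() // input.size(1),
                       dtype=input.dtype, device=input.device)
    world_size = dist.get_world_size(process_group)
    counts = [torch.empty_like(count) for _ in range(world_size)]
    dist.all_gather(counts, count, group=process_group)
    # shape (world_size,); summed or passed to the kernel in backward.
    # Dtype/casting conventions vary across PyTorch versions (see Example #8).
    return torch.cat(counts)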
Example #4
		def backward(ctx, grad_output):
			assert ctx.training, 'Inplace BatchNorm supports only training mode, input gradients with running stats used are not yet supported'
			saved_output, weight, bias, mean, invstd = ctx.saved_tensors
			saved_input = torch.batch_norm_elemt(
				saved_output, invstd.reciprocal(), mean, bias, weight.reciprocal(), 0, out = saved_output
			)
			sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(grad_output, saved_input, mean, invstd, weight, *ctx.needs_input_grad[:3])
			divisor = saved_input.numel() // saved_input.size(1)
			mean_dy = sum_dy.div_(divisor)
			mean_dy_xmu = sum_dy_xmu.div_(divisor)
			grad_input = torch.batch_norm_backward_elemt(
				grad_output, saved_input, mean, invstd, weight, mean_dy, mean_dy_xmu
			)
			return grad_input, grad_weight, grad_bias, None, None, None, None, None
Example #5
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_all = self.saved_tensors
        need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[
            0:3]

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output, saved_input, mean, invstd, weight, need_input_grad,
            need_weight_grad, need_bias_grad)

        if need_input_grad:
            # synchronizing stats used to calculate input gradient.
            sum_dy_handle = allreduce_async(sum_dy,
                                            op=Sum,
                                            name='sync_batch_norm.sum_dy')
            sum_dy_xmu_handle = allreduce_async(
                sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

            # wait on the async communication to finish
            sum_dy = synchronize(sum_dy_handle)
            sum_dy_xmu = synchronize(sum_dy_xmu_handle)

            if _SYNC_BN_V2 or _SYNC_BN_V3:
                count_all_sum = count_all.sum()
                mean_dy = sum_dy / count_all_sum
                mean_dy_xmu = sum_dy_xmu / count_all_sum
            else:
                # before 1.5.0, sum_dy was sum of means from every worker, so we just
                # need to divide it by number of workers
                mean_dy = sum_dy / size()
                mean_dy_xmu = sum_dy_xmu / size()

            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output, saved_input, mean, invstd, weight, mean_dy,
                mean_dy_xmu)
        else:
            grad_input = None

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not need_weight_grad:
            grad_weight = None

        if weight is None or not need_bias_grad:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
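
This example (and Example #8 below) references names that are not defined in the snippet: allreduce_async, synchronize, Sum, size, and the _SYNC_BN_V* flags. The following is a hedged sketch of the module-level setup it appears to assume, with Horovod as the collective backend and PyTorch-version gates matching the comments; the import style and exact thresholds are assumptions, as the comments only pin 1.5.0 and 1.9.0.

import torch
import horovod.torch as hvd          # assumed collective backend
from packaging import version        # third-party, but widely available

# collective helpers referenced by the backward above
allreduce_async = hvd.allreduce_async
synchronize = hvd.synchronize
size = hvd.size
Sum = hvd.Sum

# version gates matching the comments ("before 1.5.0 ...", "from 1.9.0 on ...")
_torch = version.parse(torch.__version__.split('+')[0])
_SYNC_BN_V2 = _torch >= version.parse('1.5.0')
_SYNC_BN_V3 = _torch >= version.parse('1.6.0')   # illustrative threshold
_SYNC_BN_V4 = _torch >= version.parse('1.9.0')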
Example #6
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, bias, count_tensor = self.saved_tensors

        # av: re-compute batch normalized out
        eps = 1e-5
        out = torch.batch_norm_elemt(saved_input, weight, bias, mean, invstd,
                                     eps)
        sigmoid_out = torch.sigmoid(out)
        grad_output *= (sigmoid_out * (1 + out * (1 - sigmoid_out)))
        # av: end

        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output, saved_input, mean, invstd, weight,
            self.needs_input_grad[0], self.needs_input_grad[1],
            self.needs_input_grad[2])

        if self.needs_input_grad[0]:
            # synchronizing stats used to calculate input gradient.
            # TODO: move div_ into batch_norm_backward_elemt kernel
            num_channels = sum_dy.shape[0]
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            torch.distributed.all_reduce(combined,
                                         torch.distributed.ReduceOp.SUM,
                                         process_group,
                                         async_op=False)
            sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

            divisor = count_tensor.sum()
            mean_dy = sum_dy / divisor
            mean_dy_xmu = sum_dy_xmu / divisor
            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output, saved_input, mean, invstd, weight, mean_dy,
                mean_dy_xmu)

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
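
The factor multiplied into grad_output at the top of this example is the derivative of out * sigmoid(out) (swish/SiLU) evaluated at the re-computed normalized output, i.e. the activation's backward is fused into the batch-norm backward. A quick self-contained check of that closed form against autograd (illustrative, not part of the source):

import torch

out = torch.randn(16, dtype=torch.double, requires_grad=True)
(out * torch.sigmoid(out)).sum().backward()      # autograd gradient of swish

s = torch.sigmoid(out)
closed_form = s * (1 + out * (1 - s))            # same factor as in the example
assert torch.allclose(out.grad, closed_form)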
Example #7
    def backward(self, grad_output):
        if not grad_output.is_contiguous(memory_format=torch.channels_last):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # calculate local stats as well as grad_weight / grad_bias
            sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2]
            )

            if self.needs_input_grad[0]:
                # synchronizing stats used to calculate input gradient.
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # backward pass for gradient calculation
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor
                )
            # synchronizing of grad_weight / grad_bias is not needed as distributed
            # training would handle all reduce.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # launch all_reduce to unblock other peer processes
                combined = torch.zeros(
                    2 * num_channels,
                    dtype=saved_input.dtype,
                    device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)

            # Leave grad_input, grad_weight and grad_bias as None, which will be
            # interpreted by the autograd engine as Tensors full of zeros.

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
Example #8
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_all = self.saved_tensors
        need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[
            0:3]

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output, saved_input, mean, invstd, weight, need_input_grad,
            need_weight_grad, need_bias_grad)

        if need_input_grad:
            # synchronizing stats used to calculate input gradient.
            sum_dy_handle = allreduce_async(sum_dy,
                                            op=Sum,
                                            name='sync_batch_norm.sum_dy')
            sum_dy_xmu_handle = allreduce_async(
                sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')

            # wait on the async communication to finish
            sum_dy = synchronize(sum_dy_handle)
            sum_dy_xmu = synchronize(sum_dy_xmu_handle)

            if _SYNC_BN_V4:
                # from 1.9.0 on we need a count tensor on all devices
                # count_all is calculated as total count across all ranks in forward function
                count_all = count_all.to(dtype=torch.int,
                                         device=grad_output.device)
            elif _SYNC_BN_V2 or _SYNC_BN_V3:
                # before 1.9.0 we need the count as an integer to compute means values
                count = count_all.sum()
            else:
                # before 1.5.0, sum_dy was sum of means from every worker, so we just
                # need to divide it by number of workers
                count = size()

            # backward pass for gradient calculation
            # we are calling into a non-public undocumented function which broke moving to 1.9.0
            # https://github.com/pytorch/pytorch/issues/57900
            if _SYNC_BN_V4:
                # from 1.9.0 on, sums and count parameters expected
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output, saved_input, mean, invstd, weight, sum_dy,
                    sum_dy_xmu, count_all)
            else:
                # before 1.9.0, mean parameters expected, not sums and count
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output, saved_input, mean, invstd, weight,
                    sum_dy / count, sum_dy_xmu / count)
        else:
            grad_input = None

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not need_weight_grad:
            grad_weight = None

        if weight is None or not need_bias_grad:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
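
For reference, the per-channel quantities returned by torch.batch_norm_backward_reduce can be cross-checked against plain tensor ops. This is a hedged sketch, assuming these undocumented kernels are CUDA-only (as in current PyTorch) and that torch.batch_norm_stats is available; tolerances are loose to absorb reduction-order differences.

import torch

if torch.cuda.is_available():
    x = torch.randn(4, 3, 8, 8, device='cuda')
    dy = torch.randn_like(x)
    weight = torch.randn(3, device='cuda')
    mean, invstd = torch.batch_norm_stats(x, 1e-5)

    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        dy, x, mean, invstd, weight, True, True, True)

    xmu = x - mean.view(1, -1, 1, 1)
    assert torch.allclose(sum_dy, dy.sum(dim=(0, 2, 3)), rtol=1e-3, atol=1e-3)
    assert torch.allclose(sum_dy_xmu, (dy * xmu).sum(dim=(0, 2, 3)), rtol=1e-3, atol=1e-3)
    assert torch.allclose(grad_bias, sum_dy, rtol=1e-3, atol=1e-3)
    assert torch.allclose(grad_weight, sum_dy_xmu * invstd, rtol=1e-3, atol=1e-3)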