Example #1
		def forward(ctx, input, weight, bias, running_mean, running_var, eps, momentum, training):
			mean, var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum) if training else (running_mean, running_var)
			invstd = (var + eps).rsqrt_()
			output = torch.batch_norm_elemt(input, weight, bias, mean, invstd, 0, out = input)
			ctx.training = training
			ctx.save_for_backward(input, weight, bias, mean, invstd)
			ctx.mark_dirty(input)
			return input
Example #2
    def forward(self, input, weight, bias, running_mean, running_var, eps,
                momentum, process_group, world_size):
        input = input.contiguous()
        count = torch.Tensor([input.numel() // input.size(1)]).to(input.device)

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count_all = torch.empty(world_size,
                                1,
                                dtype=count.dtype,
                                device=count.device)
        mean_all = torch.empty(world_size,
                               mean.size(0),
                               dtype=mean.dtype,
                               device=mean.device)
        invstd_all = torch.empty(world_size,
                                 invstd.size(0),
                                 dtype=invstd.dtype,
                                 device=invstd.device)

        count_l = list(count_all.unbind(0))
        mean_l = list(mean_all.unbind(0))
        invstd_l = list(invstd_all.unbind(0))

        # using all_gather instead of all reduce so we can calculate count/mean/var in one go
        count_all_reduce = torch.distributed.all_gather(count_l,
                                                        count,
                                                        process_group,
                                                        async_op=True)
        mean_all_reduce = torch.distributed.all_gather(mean_l,
                                                       mean,
                                                       process_group,
                                                       async_op=True)
        invstd_all_reduce = torch.distributed.all_gather(invstd_l,
                                                         invstd,
                                                         process_group,
                                                         async_op=True)

        # wait on the async communication to finish
        count_all_reduce.wait()
        mean_all_reduce.wait()
        invstd_all_reduce.wait()

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input, mean_all, invstd_all, running_mean, running_var, momentum,
            eps,
            count_all.view(-1).long().tolist())

        self.save_for_backward(input, weight, mean, invstd)
        self.process_group = process_group
        self.world_size = world_size

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out
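
Example #2 gathers directly into the views returned by unbind(0), which is why the pre-allocated count_all / mean_all / invstd_all buffers end up filled once the collectives complete. A minimal single-process sketch of that mechanism (unbind returns views into the parent tensor's storage, so writes through the views are visible in the buffer; the shapes below are illustrative):

    import torch

    mean_all = torch.empty(3, 4)               # pre-allocated gather buffer
    mean_l = list(mean_all.unbind(0))          # three views sharing mean_all's storage

    mean_l[1].copy_(torch.arange(4.0))         # a collective writing into the view ...
    print(mean_all[1])                         # ... fills the buffer: tensor([0., 1., 2., 3.])
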
Example #3
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        if not input.is_contiguous(memory_format=torch.channels_last):
            input = input.contiguous()
        weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        # calculate mean/invstd for the input on this GPU (local statistics)
        mean, invstd = torch.batch_norm_stats(input, eps)

        count = torch.full((1,), input.numel() // input.size(1),
                           dtype=mean.dtype,
                           device=mean.device)


        num_channels = input.shape[1]
        # C, C, 1 -> (2C + 1)
        combined = torch.cat([mean, invstd, count], dim=0)
        # world_size * (2C + 1)
        combined_list = [
            torch.empty_like(combined) for k in range(world_size)
        ]
        # Use allgather instead of allreduce since I don't trust in-place operations ..
        dist.all_gather(combined_list, combined, process_group, async_op=False)
        combined = torch.stack(combined_list, dim=0)
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1

        # synchronize the per-GPU statistics to obtain mean_all and invstd_all
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        # calculate the global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out
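
Examples #3, #7, #9 and #11 concatenate mean, invstd and count into a single (2C + 1) vector so that one all_gather moves all three statistics at once. A stand-alone sketch of the pack/split arithmetic without any process group (channel count and world size here are made up for illustration):

    import torch

    C, world_size = 4, 3
    mean = torch.randn(C)
    invstd = torch.rand(C) + 0.5
    count = torch.full((1,), 32.0)

    combined = torch.cat([mean, invstd, count], dim=0)      # shape (2C + 1,)
    # stand-in for dist.all_gather: every "rank" contributes one such vector
    gathered = torch.stack([combined.clone() for _ in range(world_size)], dim=0)

    # split back into (world_size, C), (world_size, C), (world_size, 1)
    mean_all, invstd_all, count_all = torch.split(gathered, C, dim=1)
    assert mean_all.shape == (world_size, C) and count_all.shape == (world_size, 1)
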
Example #4
		def backward(ctx, grad_output):
			assert ctx.training, 'Inplace BatchNorm supports only training mode, input gradients with running stats used are not yet supported'
			saved_output, weight, bias, mean, invstd = ctx.saved_tensors
			saved_input = torch.batch_norm_elemt(
				saved_output, invstd.reciprocal(), mean, bias, weight.reciprocal(), 0, out = saved_output
			)
			sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(grad_output, saved_input, mean, invstd, weight, *ctx.needs_input_grad[:3])
			divisor = saved_input.numel() // saved_input.size(1)
			mean_dy = sum_dy.div_(divisor)
			mean_dy_xmu = sum_dy_xmu.div_(divisor)
			grad_input = torch.batch_norm_backward_elemt(
				grad_output, saved_input, mean, invstd, weight, mean_dy, mean_dy_xmu
			)
			return grad_input, grad_weight, grad_bias, None, None, None, None, None
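
The backward in Example #4 recovers the original input from the saved output by inverting the affine normalization: since out = (x - mean) * invstd * weight + bias, the input is x = (out - bias) / (weight * invstd) + mean, which is what the batch_norm_elemt call with reciprocal weight/invstd and mean/bias swapped into the bias/mean slots computes. A quick numerical check of that identity with plain tensor ops (a sketch, not part of the snippet):

    import torch

    x = torch.randn(8, 3, 4, 4)
    weight, bias = torch.rand(3) + 0.1, torch.randn(3)
    mean = x.mean(dim=(0, 2, 3))
    invstd = (x.var(dim=(0, 2, 3), unbiased=False) + 1e-5).rsqrt()

    shape = (1, 3, 1, 1)
    out = (x - mean.view(shape)) * invstd.view(shape) * weight.view(shape) + bias.view(shape)
    x_rec = (out - bias.view(shape)) * weight.reciprocal().view(shape) \
            * invstd.reciprocal().view(shape) + mean.view(shape)
    assert torch.allclose(x, x_rec, atol=1e-4)
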
Example #5
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        input = input.contiguous()

        count = torch.empty(1,
                            dtype=running_mean.dtype,
                            device=input.device).fill_(input.numel() // input.size(1))

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count_all = torch.empty(world_size, 1, dtype=count.dtype, device=count.device)
        mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
        invstd_all = torch.empty(world_size, invstd.size(0), dtype=invstd.dtype, device=invstd.device)

        count_l = list(count_all.unbind(0))
        mean_l = list(mean_all.unbind(0))
        invstd_l = list(invstd_all.unbind(0))

        # using all_gather instead of all reduce so we can calculate count/mean/var in one go
        count_all_reduce = torch.distributed.all_gather(count_l, count, process_group, async_op=True)
        mean_all_reduce = torch.distributed.all_gather(mean_l, mean, process_group, async_op=True)
        invstd_all_reduce = torch.distributed.all_gather(invstd_l, invstd, process_group, async_op=True)

        # wait on the async communication to finish
        count_all_reduce.wait()
        mean_all_reduce.wait()
        invstd_all_reduce.wait()

        size = count_all.view(-1).long().sum()
        if size == 1:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all)
        self.process_group = process_group

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out
Example #6
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, bias, count_tensor = self.saved_tensors

        # av: re-compute batch normalized out
        eps = 1e-5
        out = torch.batch_norm_elemt(saved_input, weight, bias, mean, invstd,
                                     eps)
        sigmoid_out = torch.sigmoid(out)
        grad_output *= (sigmoid_out * (1 + out * (1 - sigmoid_out)))
        # av: end

        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output, saved_input, mean, invstd, weight,
            self.needs_input_grad[0], self.needs_input_grad[1],
            self.needs_input_grad[2])

        if self.needs_input_grad[0]:
            # synchronizing stats used to calculate input gradient.
            # TODO: move div_ into batch_norm_backward_elemt kernel
            num_channels = sum_dy.shape[0]
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            torch.distributed.all_reduce(combined,
                                         torch.distributed.ReduceOp.SUM,
                                         process_group,
                                         async_op=False)
            sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

            divisor = count_tensor.sum()
            mean_dy = sum_dy / divisor
            mean_dy_xmu = sum_dy_xmu / divisor
            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output, saved_input, mean, invstd, weight, mean_dy,
                mean_dy_xmu)

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
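
Example #6 folds the derivative of Swish, sigmoid(out) * (1 + out * (1 - sigmoid(out))), into grad_output before running the batch-norm backward. A quick sanity check of that closed-form derivative against autograd (stand-alone sketch):

    import torch

    out = torch.randn(10, requires_grad=True)
    (out * torch.sigmoid(out)).sum().backward()     # autograd gradient of Swish

    s = torch.sigmoid(out.detach())
    manual = s * (1 + out.detach() * (1 - s))       # closed-form Swish derivative
    assert torch.allclose(out.grad, manual, atol=1e-6)
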
Example #7
    def forward(self, input, weight, bias, running_mean, running_var, eps,
                momentum, process_group, world_size):
        input = input.contiguous()

        count = torch.empty(1, dtype=running_mean.dtype,
                            device=input.device).fill_(input.numel() //
                                                       input.size(1))

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        num_channels = input.shape[1]
        # C, C, 1 -> (2C + 1)
        combined = torch.cat([mean, invstd, count], dim=0)
        # world_size * (2C + 1)
        combined_list = [torch.empty_like(combined) for k in range(world_size)]
        # Use allgather instead of allreduce since I don't trust in-place operations ..
        dist.all_gather(combined_list, combined, async_op=False)
        combined = torch.stack(combined_list, dim=0)
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined,
                                                      num_channels,
                                                      dim=1)

        size = count_all.view(-1).long().sum()
        if size == 1:
            raise ValueError(
                'Expected more than 1 value per channel when training, got input size {}'
                .format(size))

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input, mean_all, invstd_all, running_mean, running_var, momentum,
            eps, count_all.view(-1))

        self.save_for_backward(input, weight, mean, invstd, bias, count_all)
        self.process_group = process_group

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)

        # av: apply swish
        assert eps == 1e-5, "I assumed below that eps is 1e-5"
        out = out * torch.sigmoid(out)
        # av: end

        return out
Example #8
    def forward(self, input, weight, bias, running_mean, running_var, eps,
                momentum):
        input = input.contiguous()

        size = input.numel() // input.size(1)
        count = torch.tensor([size])

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count_handle = allgather_async(count.unsqueeze(0),
                                       name='sync_batch_norm.count')
        mean_handle = allgather_async(mean.unsqueeze(0),
                                      name='sync_batch_norm.mean')
        invstd_handle = allgather_async(invstd.unsqueeze(0),
                                        name='sync_batch_norm.invstd')

        # wait on the async communication to finish
        count_all = synchronize(count_handle)
        mean_all = synchronize(mean_handle)
        invstd_all = synchronize(invstd_handle)

        if _SYNC_BN_V2:
            counts_for_bngswc = count_all.view(-1).float().to(input.device)
        else:
            # backwards compatibility
            counts_for_bngswc = count_all.view(-1).tolist()

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input, mean_all, invstd_all, running_mean, running_var, momentum,
            eps, counts_for_bngswc)

        self.save_for_backward(input, weight, mean, invstd, count_all)

        # apply element-wise normalization
        return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
Example #9
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        if not input.is_contiguous(memory_format=torch.channels_last):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # for empty input, set the stats and the count to zero. Stats with a
            # zero count will be filtered out later when computing the global mean
            # & invstd, but they still need to participate in the all_gather
            # collective communication to unblock the other peer processes.
            combined = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device
            )

        # Use allgather instead of allreduce because count can differ across
        # ranks; a simple all_reduce cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
        # all gathered mean, invstd and count.
        # for nccl backend, use the optimized version of all gather.
        if process_group._get_backend_name() == 'nccl':
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for _ in range(world_size)
            ]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        # remove stats from empty inputs
        mask = count_all.squeeze(-1) >= 1
        count_all = count_all[mask]
        mean_all = mean_all[mask]
        invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)
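
Example #9 masks out the all-zero rows contributed by ranks that saw an empty input before computing the global statistics. A small stand-alone sketch of that masking step (the gathered values below are made up for illustration):

    import torch

    # gathered stats from 3 ranks; rank 1 had an empty input and contributed zeros
    count_all = torch.tensor([[32.0], [0.0], [16.0]])
    mean_all = torch.tensor([[0.1, 0.2], [0.0, 0.0], [0.3, 0.4]])
    invstd_all = torch.ones(3, 2)

    mask = count_all.squeeze(-1) >= 1       # keep only ranks that actually saw data
    count_all, mean_all, invstd_all = count_all[mask], mean_all[mask], invstd_all[mask]
    print(count_all.view(-1))               # tensor([32., 16.])
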
Example #10
    def forward(self, input, weight, bias, running_mean, running_var, eps,
                momentum, process_group, world_size, rank):
        input = input.contiguous()

        size = input.numel() // input.size(1)
        if size == 1:
            raise ValueError(
                'Expected more than 1 value per channel when training, got input size {}'
                .format(size))
        count = torch.Tensor([size]).to(input.device)

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count_all = torch.empty(world_size,
                                1,
                                dtype=count.dtype,
                                device=count.device)
        mean_all = torch.empty(world_size,
                               mean.size(0),
                               dtype=mean.dtype,
                               device=mean.device)
        invstd_all = torch.empty(world_size,
                                 invstd.size(0),
                                 dtype=invstd.dtype,
                                 device=invstd.device)

        count_l = list(count_all.unbind(0))
        mean_l = list(mean_all.unbind(0))
        invstd_l = list(invstd_all.unbind(0))

        # using all_gather instead of all reduce so we can calculate count/mean/var in one go
        count_all_reduce = torch.distributed.all_gather(count_l,
                                                        count,
                                                        process_group,
                                                        async_op=True)
        mean_all_reduce = torch.distributed.all_gather(mean_l,
                                                       mean,
                                                       process_group,
                                                       async_op=True)
        invstd_all_reduce = torch.distributed.all_gather(invstd_l,
                                                         invstd,
                                                         process_group,
                                                         async_op=True)

        # wait on the async communication to finish
        count_all_reduce.wait()
        mean_all_reduce.wait()
        invstd_all_reduce.wait()

        ### uncomment to check result:
        # print('[%d]mean before shuffle:'%rank, mean_l[rank][0:4])

        # shuffle global mean & invstd
        new_rank = forward_shuffle(rank, world_size)
        count = count_l[new_rank]
        mean = mean_l[new_rank].view(1, -1)
        invstd = invstd_l[new_rank].view(1, -1)

        mean, invstd = torch.batch_norm_gather_stats(input, mean, invstd,
                                                     running_mean, running_var,
                                                     momentum, eps,
                                                     count.long().item())

        ### uncomment to check result:
        # print('[%d]mean after shuffle:'%rank, mean[0:4])

        self.save_for_backward(input, weight, mean, invstd)
        self.process_group = process_group
        self.world_size = world_size
        self.rank = rank

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out
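
Example #10 is a shuffled variant: each rank normalizes with the statistics of another rank selected by forward_shuffle, whose implementation is not shown in this snippet. Purely for illustration, a hypothetical stand-in could be a simple ring shift (the actual shuffle used by the snippet may differ):

    def forward_shuffle(rank, world_size):
        # hypothetical stand-in only: each rank adopts the statistics gathered
        # from the next rank; the snippet's real shuffle is not shown here
        return (rank + 1) % world_size
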
Example #11
    def forward(self, input, weight, bias, running_mean, running_var, eps,
                momentum, process_group, world_size):
        if not input.is_contiguous(memory_format=torch.channels_last):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError(
                'Expected more than 1 value per channel when training, got input size {}'
                .format(size))

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count = torch.full((1, ),
                           input.numel() // input.size(1),
                           dtype=mean.dtype,
                           device=mean.device)

        num_channels = input.shape[1]
        # C, C, 1 -> (2C + 1)
        combined = torch.cat([mean, invstd, count], dim=0)
        # Use allgather instead of allreduce because count can differ across
        # ranks; a simple all_reduce cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
        # all gathered mean, invstd and count.
        # for nccl backend, use the optimized version of all gather.
        if process_group._get_backend_name() == 'nccl':
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist._all_gather_base(combined_flat,
                                  combined,
                                  process_group,
                                  async_op=False)
            combined = torch.reshape(combined_flat,
                                     (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined,
                                                          num_channels,
                                                          dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for k in range(world_size)
            ]
            dist.all_gather(combined_list,
                            combined,
                            process_group,
                            async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined,
                                                          num_channels,
                                                          dim=1)

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input, mean_all, invstd_all, running_mean, running_var, momentum,
            eps, count_all.view(-1))

        self.save_for_backward(input, weight, mean, invstd,
                               count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out