def forward(self, input):
    if not self.training or not dist.is_initialized():
        # eval (or single-process) path: normalize with the running statistics
        bn = (input - self.running_mean.view(1, self.running_mean.shape[0], 1, 1)) / \
            (torch.sqrt(self.running_var.view(1, self.running_var.shape[0], 1, 1) + self.eps))
        return bn.mul(self.weight.view(1, self.weight.shape[0], 1, 1)).add(self.bias.view(1, self.bias.shape[0], 1, 1))

    # per-shard statistics; invstd = 1 / sqrt(var + eps), so var is recovered below
    shard_mean, shard_invstd = torch.batch_norm_stats(input, self.eps)
    shard_vars = (1. / shard_invstd) ** 2 - self.eps
    shard_square_of_mean = torch.mul(shard_mean, shard_mean)
    shard_mean_of_square = shard_vars + shard_square_of_mean

    # average E[x] and E[x^2] across the group, then Var[x] = E[x^2] - E[x]^2
    group_mean = shard_mean.clone().detach()
    self._reduce_avg(group_mean)
    group_mean_of_square = shard_mean_of_square.clone().detach()
    self._reduce_avg(group_mean_of_square)
    group_vars = group_mean_of_square - torch.mul(group_mean, group_mean)

    group_mean = group_mean.detach()
    group_vars = group_vars.detach()

    # update the running statistics with the group-wide values
    self.running_mean.mul_(1. - self.momentum).add_(group_mean.mul(self.momentum))
    self.running_var.mul_(1. - self.momentum).add_(group_vars.mul(self.momentum))
    self.num_batches_tracked.add_(1)

    bn = (input - group_mean.view(1, group_mean.shape[0], 1, 1)) / \
        (torch.sqrt(group_vars.view(1, group_vars.shape[0], 1, 1) + self.eps))
    return bn.mul(self.weight.view(1, self.weight.shape[0], 1, 1)).add(self.bias.view(1, self.bias.shape[0], 1, 1))
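# The snippet above calls self._reduce_avg, which is not shown. A minimal
# sketch of such a helper, assuming torch.distributed is imported as dist
# (this body is an assumption, not taken from the source):
def _reduce_avg(self, tensor):
    # sum the tensor across all ranks in place, then divide by the world
    # size so every rank ends up holding the group average
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())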
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
    input = input.contiguous()

    count = torch.Tensor([input.numel() // input.size(1)]).to(input.device)

    # calculate mean/invstd for input.
    mean, invstd = torch.batch_norm_stats(input, eps)

    count_all = torch.empty(world_size, 1, dtype=count.dtype, device=count.device)
    mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
    invstd_all = torch.empty(world_size, invstd.size(0), dtype=invstd.dtype, device=invstd.device)
    # unbind(0) yields views into the preallocated buffers, so the gathers
    # below fill count_all/mean_all/invstd_all in place
    count_l = list(count_all.unbind(0))
    mean_l = list(mean_all.unbind(0))
    invstd_l = list(invstd_all.unbind(0))

    # using all_gather instead of all reduce so we can calculate count/mean/var in one go
    count_all_reduce = torch.distributed.all_gather(count_l, count, process_group, async_op=True)
    mean_all_reduce = torch.distributed.all_gather(mean_l, mean, process_group, async_op=True)
    invstd_all_reduce = torch.distributed.all_gather(invstd_l, invstd, process_group, async_op=True)

    # wait on the async communication to finish
    count_all_reduce.wait()
    mean_all_reduce.wait()
    invstd_all_reduce.wait()

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        count_all.view(-1).long().tolist()
    )

    self.save_for_backward(input, weight, mean, invstd)
    self.process_group = process_group
    self.world_size = world_size

    # apply element-wise normalization
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return out
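# A quick standalone check of the unbind trick used above: unbind(0) returns
# views into the preallocated buffer, so when all_gather writes into the list
# entries it fills count_all/mean_all/invstd_all in place (the values here
# are made up for illustration):
import torch

buf = torch.zeros(2, 3)
rows = list(buf.unbind(0))
rows[1].copy_(torch.ones(3))    # writing through the view mutates buf
assert buf[1].sum().item() == 3.0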
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
    if not input.is_contiguous(memory_format=torch.channels_last):
        input = input.contiguous()
    weight = weight.contiguous()

    size = int(input.numel() // input.size(1))
    if size == 1 and world_size < 2:
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

    # calculate mean/invstd for the input on this device.
    mean, invstd = torch.batch_norm_stats(input, eps)

    count = torch.full((1,), input.numel() // input.size(1), dtype=mean.dtype, device=mean.device)

    num_channels = input.shape[1]
    # C, C, 1 -> (2C + 1)
    combined = torch.cat([mean, invstd, count], dim=0)
    # world_size * (2C + 1)
    combined_list = [torch.empty_like(combined) for _ in range(world_size)]
    # Use allgather instead of allreduce since I don't trust in-place operations ..
    dist.all_gather(combined_list, combined, process_group, async_op=False)
    combined = torch.stack(combined_list, dim=0)
    # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
    # synchronize the per-device stats to obtain mean_all and invstd_all.
    mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

    # calculate the global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        count_all.view(-1)
    )

    self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
    self.process_group = process_group

    # apply element-wise normalization
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return out
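# A quick standalone check of the pack-and-split trick used above: each rank
# packs [mean (C), invstd (C), count (1)] into one (2C + 1)-element tensor,
# and after stacking, torch.split with chunk size C recovers the three pieces
# column-wise (C and world_size here are made up for illustration):
import torch

C, world_size = 3, 2
combined = torch.stack([torch.arange(2 * C + 1, dtype=torch.float) for _ in range(world_size)])
mean_all, invstd_all, count_all = torch.split(combined, C, dim=1)
assert mean_all.shape == (world_size, C) and count_all.shape == (world_size, 1)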
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
    input = input.contiguous()

    count = torch.empty(1, dtype=running_mean.dtype, device=input.device).fill_(input.numel() // input.size(1))

    # calculate mean/invstd for input.
    mean, invstd = torch.batch_norm_stats(input, eps)

    count_all = torch.empty(world_size, 1, dtype=count.dtype, device=count.device)
    mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
    invstd_all = torch.empty(world_size, invstd.size(0), dtype=invstd.dtype, device=invstd.device)
    count_l = list(count_all.unbind(0))
    mean_l = list(mean_all.unbind(0))
    invstd_l = list(invstd_all.unbind(0))

    # using all_gather instead of all reduce so we can calculate count/mean/var in one go
    count_all_reduce = torch.distributed.all_gather(count_l, count, process_group, async_op=True)
    mean_all_reduce = torch.distributed.all_gather(mean_l, mean, process_group, async_op=True)
    invstd_all_reduce = torch.distributed.all_gather(invstd_l, invstd, process_group, async_op=True)

    # wait on the async communication to finish
    count_all_reduce.wait()
    mean_all_reduce.wait()
    invstd_all_reduce.wait()

    size = count_all.view(-1).long().sum()
    if size == 1:
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        count_all.view(-1)
    )

    self.save_for_backward(input, weight, mean, invstd, count_all)
    self.process_group = process_group

    # apply element-wise normalization
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return out
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
    input = input.contiguous()

    count = torch.empty(1, dtype=running_mean.dtype, device=input.device).fill_(input.numel() // input.size(1))

    # calculate mean/invstd for input.
    mean, invstd = torch.batch_norm_stats(input, eps)

    num_channels = input.shape[1]
    # C, C, 1 -> (2C + 1)
    combined = torch.cat([mean, invstd, count], dim=0)
    # world_size * (2C + 1)
    combined_list = [torch.empty_like(combined) for _ in range(world_size)]
    # Use allgather instead of allreduce since I don't trust in-place operations ..
    dist.all_gather(combined_list, combined, process_group, async_op=False)
    combined = torch.stack(combined_list, dim=0)
    # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
    mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

    size = count_all.view(-1).long().sum()
    if size == 1:
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        count_all.view(-1))

    self.save_for_backward(input, weight, mean, invstd, bias, count_all)
    self.process_group = process_group

    # apply element-wise normalization
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)

    # av: apply swish
    assert eps == 1e-5, "I assumed below that eps is 1e-5"
    out = out * torch.sigmoid(out)
    # av: end

    return out
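# The fused activation above, out * torch.sigmoid(out), is the Swish/SiLU
# function; on torch >= 1.7 the same result is available as
# torch.nn.functional.silu(out). A quick equivalence check:
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
assert torch.allclose(x * torch.sigmoid(x), F.silu(x))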
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum):
    input = input.contiguous()

    size = input.numel() // input.size(1)
    count = torch.tensor([size])

    # calculate mean/invstd for input.
    mean, invstd = torch.batch_norm_stats(input, eps)

    count_handle = allgather_async(count.unsqueeze(0), name='sync_batch_norm.count')
    mean_handle = allgather_async(mean.unsqueeze(0), name='sync_batch_norm.mean')
    invstd_handle = allgather_async(invstd.unsqueeze(0), name='sync_batch_norm.invstd')

    # wait on the async communication to finish
    count_all = synchronize(count_handle)
    mean_all = synchronize(mean_handle)
    invstd_all = synchronize(invstd_handle)

    if _SYNC_BN_V2:
        counts_for_bngswc = count_all.view(-1).float().to(input.device)
    else:
        # backwards compatibility
        counts_for_bngswc = count_all.view(-1).tolist()

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        counts_for_bngswc)

    self.save_for_backward(input, weight, mean, invstd, count_all)

    # apply element-wise normalization
    return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
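# _SYNC_BN_V2 above is assumed to be a module-level flag marking torch builds
# whose batch_norm_gather_stats_with_counts accepts a float tensor of counts
# rather than a Python list. A minimal sketch of such a flag (the exact
# version cutoff is an assumption, not taken from the source):
_SYNC_BN_V2 = tuple(int(v) for v in torch.__version__.split('.')[:2]) >= (1, 5)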
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
    if not input.is_contiguous(memory_format=torch.channels_last):
        input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    size = int(input.numel() // input.size(1))
    if size == 1 and world_size < 2:
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

    num_channels = input.shape[1]
    if input.numel() > 0:
        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count = torch.full(
            (1,),
            input.numel() // input.size(1),
            dtype=mean.dtype,
            device=mean.device
        )

        # C, C, 1 -> (2C + 1)
        combined = torch.cat([mean, invstd, count], dim=0)
    else:
        # for empty input, set stats and the count to zero. The stats with
        # zero count will be filtered out later when computing global mean
        # & invstd, but they still need to participate in the all_gather
        # collective communication to unblock other peer processes.
        combined = torch.zeros(
            2 * num_channels + 1,
            dtype=input.dtype,
            device=input.device
        )

    # Use allgather instead of allreduce because count could be different across
    # ranks, simple all reduce op can not give correct results.
    # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
    # all gathered mean, invstd and count.
    # for nccl backend, use the optimized version of all gather.
    if process_group._get_backend_name() == 'nccl':
        # world_size * (2C + 1)
        combined_size = combined.numel()
        combined_flat = torch.empty(1,
                                    combined_size * world_size,
                                    dtype=combined.dtype,
                                    device=combined.device)
        dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
        combined = torch.reshape(combined_flat, (world_size, combined_size))
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
    else:
        # world_size * (2C + 1)
        combined_list = [
            torch.empty_like(combined) for _ in range(world_size)
        ]
        dist.all_gather(combined_list, combined, process_group, async_op=False)
        combined = torch.stack(combined_list, dim=0)
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

    # remove stats from empty inputs
    mask = count_all.squeeze(-1) >= 1
    count_all = count_all[mask]
    mean_all = mean_all[mask]
    invstd_all = invstd_all[mask]

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        count_all.view(-1)
    )

    self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
    self.process_group = process_group

    # apply element-wise normalization
    if input.numel() > 0:
        return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    else:
        return torch.empty_like(input)
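# A quick standalone check of the empty-input filtering above: ranks that saw
# no data contribute an all-zero (2C + 1) row, and the count >= 1 mask drops
# those rows before the global reduction (the values here are made up for
# illustration):
import torch

count_all = torch.tensor([[4.], [0.], [2.]])
mean_all = torch.tensor([[1., 1.], [0., 0.], [3., 3.]])
mask = count_all.squeeze(-1) >= 1
assert mean_all[mask].shape == (2, 2) and count_all[mask].shape == (2, 1)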
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size, rank):
    input = input.contiguous()

    size = input.numel() // input.size(1)
    if size == 1:
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
    count = torch.Tensor([size]).to(input.device)

    # calculate mean/invstd for input.
    mean, invstd = torch.batch_norm_stats(input, eps)

    count_all = torch.empty(world_size, 1, dtype=count.dtype, device=count.device)
    mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
    invstd_all = torch.empty(world_size, invstd.size(0), dtype=invstd.dtype, device=invstd.device)
    count_l = list(count_all.unbind(0))
    mean_l = list(mean_all.unbind(0))
    invstd_l = list(invstd_all.unbind(0))

    # using all_gather instead of all reduce so we can calculate count/mean/var in one go
    count_all_reduce = torch.distributed.all_gather(count_l, count, process_group, async_op=True)
    mean_all_reduce = torch.distributed.all_gather(mean_l, mean, process_group, async_op=True)
    invstd_all_reduce = torch.distributed.all_gather(invstd_l, invstd, process_group, async_op=True)

    # wait on the async communication to finish
    count_all_reduce.wait()
    mean_all_reduce.wait()
    invstd_all_reduce.wait()

    ### uncomment to check result:
    # print('[%d]mean before shuffle:' % rank, mean_l[rank][0:4])

    # shuffle global mean & invstd: each rank adopts the statistics of the
    # peer that forward_shuffle maps it to
    new_rank = forward_shuffle(rank, world_size)
    count = count_l[new_rank]
    mean = mean_l[new_rank].view(1, -1)
    invstd = invstd_l[new_rank].view(1, -1)
    mean, invstd = torch.batch_norm_gather_stats(input, mean, invstd, running_mean, running_var,
                                                 momentum, eps, count.long().item())

    ### uncomment to check result:
    # print('[%d]mean after shuffle:' % rank, mean[0:4])

    self.save_for_backward(input, weight, mean, invstd)
    self.process_group = process_group
    self.world_size = world_size
    self.rank = rank

    # apply element-wise normalization
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return out
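# forward_shuffle above is not defined in the snippet; it maps a rank to the
# peer whose statistics it should adopt. A minimal sketch assuming a fixed
# cyclic permutation (the actual shuffle used by the source may differ):
def forward_shuffle(rank, world_size):
    # rotate ranks by one; every rank must compute the same permutation for
    # the exchange to stay consistent across processes
    return (rank + 1) % world_size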
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
    if not input.is_contiguous(memory_format=torch.channels_last):
        input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    size = int(input.numel() // input.size(1))
    if size == 1 and world_size < 2:
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

    # calculate mean/invstd for input.
    mean, invstd = torch.batch_norm_stats(input, eps)

    count = torch.full((1,),
                       input.numel() // input.size(1),
                       dtype=mean.dtype,
                       device=mean.device)

    num_channels = input.shape[1]
    # C, C, 1 -> (2C + 1)
    combined = torch.cat([mean, invstd, count], dim=0)

    # Use allgather instead of allreduce because count could be different across
    # ranks, simple all reduce op can not give correct results.
    # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
    # all gathered mean, invstd and count.
    # for nccl backend, use the optimized version of all gather.
    if process_group._get_backend_name() == 'nccl':
        # world_size * (2C + 1)
        combined_size = combined.numel()
        combined_flat = torch.empty(1,
                                    combined_size * world_size,
                                    dtype=combined.dtype,
                                    device=combined.device)
        dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
        combined = torch.reshape(combined_flat, (world_size, combined_size))
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
    else:
        # world_size * (2C + 1)
        combined_list = [
            torch.empty_like(combined) for _ in range(world_size)
        ]
        dist.all_gather(combined_list, combined, process_group, async_op=False)
        combined = torch.stack(combined_list, dim=0)
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        count_all.view(-1))

    self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
    self.process_group = process_group

    # apply element-wise normalization
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return out
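# A minimal sketch of how a SyncBatchNorm module's forward might dispatch to
# an autograd Function like the ones above (the name SyncBatchNormFunc and
# this wiring are illustrative assumptions, not taken from the source):
import torch
import torch.distributed as dist

def module_forward(self, input):
    if not self.training or not dist.is_initialized():
        # eval path: normalize with the running statistics
        return torch.batch_norm_elemt(
            input, self.weight, self.bias, self.running_mean,
            torch.rsqrt(self.running_var + self.eps), self.eps)
    process_group = dist.group.WORLD
    world_size = dist.get_world_size(process_group)
    return SyncBatchNormFunc.apply(
        input, self.weight, self.bias, self.running_mean, self.running_var,
        self.eps, self.momentum, process_group, world_size)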