import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm

try:
    # Standard location of the coalesced reduce / broadcast autograd functions
    # used for the cross-GPU synchronization below.
    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
    ReduceAddCoalesced = Broadcast = None

# SyncMaster, the _ChildMessage/_MasterMessage namedtuples, the _sum_ft and
# _unsqueeze_ft helpers, Delinear, FastDeconvTransposed, and the
# isqrt_newton_schulz_autograd(_batch) routines are provided by this repo's
# companion modules and are expected to be imported alongside.


class SynchronizedDelinear(Delinear):
    def __init__(self, in_features, out_features, bias=True, eps=1e-5,
                 n_iter=5, momentum=0.1, block=512):
        assert ReduceAddCoalesced is not None, 'Can not use SynchronizedDelinear without CUDA support.'
        super(SynchronizedDelinear, self).__init__(in_features, out_features,
                                                   bias, eps, n_iter,
                                                   momentum, block)

        self._sync_master = SyncMaster(self._data_parallel_master)
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use the original implementation.
        if not (self._is_parallel and self.training):
            return super().forward(input)

        X = input.view(-1, self.block)
        X_sum = X.sum(0)
        XY_sum = X.t() @ X
        sum_size = X.shape[0]

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            X_mean, cov_isqrt = self._sync_master.run_master(
                _ChildMessage(X_sum, XY_sum, sum_size))
        else:
            X_mean, cov_isqrt = self._slave_pipe.run_slave(
                _ChildMessage(X_sum, XY_sum, sum_size))

        w = self.weight.view(-1, self.block) @ cov_isqrt
        if self.bias is None:
            b = -(w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
        else:
            b = self.bias - (w @ (X_mean.unsqueeze(1))).view(
                self.weight.shape[0], -1).sum(1)
        w = w.view(self.weight.shape)
        return F.linear(input, w, b)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].X_sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].X_sum.get_device() for i in intermediates]

        total_sum = sum([i[1].sum_size for i in intermediates])
        X_sum, XY_sum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        X_mean, cov_isqrt = self._compute_mean_isqrt(X_sum, XY_sum, total_sum)

        broadcasted = Broadcast.apply(target_gpus, X_mean, cov_isqrt)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_isqrt(self, X_sum, XY_sum, size):
        """Compute the mean and the inverse square root of the covariance from
        the sum and the outer-product sum. This method also maintains the
        moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        X_mean = X_sum / size
        Cov = (XY_sum - (X_sum.unsqueeze(1)) @ X_mean.unsqueeze(0)) / size  # check here.
        Id = torch.eye(Cov.shape[1], dtype=Cov.dtype, device=Cov.device)
        cov_isqrt = isqrt_newton_schulz_autograd(Cov + self.eps * Id, self.n_iter)

        self.running_mean.mul_(1 - self.momentum)
        self.running_mean.add_(X_mean.detach() * self.momentum)
        self.running_cov_isqrt.mul_(1 - self.momentum)
        self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)

        return X_mean, cov_isqrt
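# Usage sketch (illustration only). The synchronized layers in this file take
# their cross-GPU path only when a data-parallel wrapper invokes
# __data_parallel_replicate__ on every replica; plain nn.DataParallel never
# calls it, so the layers silently fall back to per-GPU statistics. The wrapper
# is passed in as an argument because its exact name and location in this repo
# are assumptions (Synchronized-BatchNorm-PyTorch calls it DataParallelWithCallback).
def _example_synchronized_delinear(data_parallel_wrapper):
    import torch
    from torch import nn

    model = nn.Sequential(
        SynchronizedDelinear(512, 256, block=512),  # whitened linear layer defined above
        nn.ReLU(),
        nn.Linear(256, 10),
    ).cuda()

    # The wrapper replicates the model and calls __data_parallel_replicate__ on
    # each copy, so copy 0 becomes the master and the others register as slaves.
    model = data_parallel_wrapper(model, device_ids=[0, 1])

    out = model(torch.randn(64, 512).cuda())
    return out.shape  # expected: torch.Size([64, 10])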
class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps,
                                                     momentum=momentum,
                                                     affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(input, self.running_mean, self.running_var,
                                self.weight, self.bias, self.training,
                                self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(
                _ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(
                _ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(
                inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum.
        This method also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean \
                    + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var \
                    + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean \
                + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var \
                + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5
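# Standalone sketch of the reduction math in _compute_mean_std above: summing
# the per-device (sum, square-sum, count) triples and then forming the moments
# on the master reproduces the statistics of the full batch. Only torch is
# used; the shapes and numbers below are arbitrary.
def _check_sum_ssum_reduction():
    import torch

    x = torch.randn(4, 3, 8, 8)               # stand-in for the global batch (N, C, H, W)
    chunks = x.chunk(2, dim=0)                 # pretend each chunk lives on its own GPU

    # Per-device partial statistics, reduced over batch and spatial dims.
    sum_ = sum(c.sum(dim=(0, 2, 3)) for c in chunks)
    ssum = sum((c ** 2).sum(dim=(0, 2, 3)) for c in chunks)
    size = sum(c.numel() // c.size(1) for c in chunks)

    mean = sum_ / size
    bias_var = (ssum - sum_ * mean) / size     # E[x^2] - E[x]^2
    inv_std = bias_var.clamp(1e-5) ** -0.5     # what the master broadcasts back

    ref = x.transpose(0, 1).reshape(3, -1)     # per-channel view of the whole batch
    assert torch.allclose(mean, ref.mean(1), atol=1e-5)
    assert torch.allclose(bias_var, ref.var(1, unbiased=False), atol=1e-5)
    return mean, inv_std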
class SynchronizedDeconvTransposed(FastDeconvTransposed):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True, dilation=1,
                 eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3):
        assert ReduceAddCoalesced is not None, 'Can not use SynchronizedDeconvTransposed without CUDA support.'
        super(SynchronizedDeconvTransposed, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            output_padding, groups, bias, dilation, eps, n_iter, momentum,
            block, sampling_stride)

        if not self.track_running_stats:
            import warnings
            warnings.warn(
                'track_running_stats=False is not supported by the SynchronizedDeconvTransposed.'
            )

        self._sync_master = SyncMaster(self._data_parallel_master)
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, x, output_size=None):
        # If it is not parallel computation or is in evaluation mode, use the original implementation.
        if not (self._is_parallel and self.training):
            return super().forward(x, output_size)

        N, C, H, W = x.shape
        B = self.block

        if self.training and self.track_running_stats:
            self.counter += 1

        # im2col: N x cols x pixels -> N*pixels x cols
        if self.kernel_size[0] > 1:
            # The adjoint of a conv operation is a full correlation operation, so pad first.
            padding = (self.padding[0] + self.kernel_size[0] - 1,
                       self.padding[1] + self.kernel_size[1] - 1)
            X = torch.nn.functional.unfold(x, self.kernel_size, self.dilation,
                                           self.padding,
                                           self.sampling_stride).transpose(
                                               1, 2).contiguous()
        else:
            # channel-wise
            X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride ** 2, :]

        if self.groups == 1:
            # (C//B*N*pixels, k*k*B)
            X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
        else:
            X = X.view(-1, X.shape[-1])

        X_sum = X.sum(0)
        if self.groups == 1:
            XY_sum = X.t() @ X
            sum_size = X.shape[0]
        else:
            XY_sum = X.transpose(1, 2) @ X
            sum_size = X.shape[1]

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            X_mean, cov_isqrt = self._sync_master.run_master(
                _ChildMessage(X_sum, XY_sum, sum_size))
        else:
            X_mean, cov_isqrt = self._slave_pipe.run_slave(
                _ChildMessage(X_sum, XY_sum, sum_size))

        # X * deconv * conv = X * (deconv * conv)
        weight = torch.flip(self.weight, [2, 3])

        if self.groups == 1:
            w = weight.view(-1, self.num_features, C // B).transpose(
                1, 2).contiguous().view(-1, self.num_features) @ cov_isqrt
            b = self.bias - (w @ (X_mean.unsqueeze(1))).view(
                self.weight.shape[0], -1).sum(1)
            w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
        else:
            w = self.weight.view(C // B, -1, self.num_features) @ cov_isqrt
            b = self.bias - (w @ (X_mean.view(-1, self.num_features, 1))).view(self.bias.shape)
            w = w.view(self.weight.shape)

        w = torch.flip(w.view(weight.shape), [2, 3])

        output_padding = self._output_padding(x, output_size, self.stride,
                                              self.padding, self.kernel_size)
        x = F.conv_transpose2d(x, w, b, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        return x

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].X_sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].X_sum.get_device() for i in intermediates]

        total_sum = sum([i[1].sum_size for i in intermediates])
        X_sum, XY_sum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        X_mean, cov_isqrt = self._compute_mean_isqrt(X_sum, XY_sum, total_sum)

        broadcasted = Broadcast.apply(target_gpus, X_mean, cov_isqrt)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_isqrt(self, X_sum, XY_sum, size):
        """Compute the mean and the inverse square root of the covariance from
        the sum and the outer-product sum. This method also maintains the
        moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        X_mean = X_sum / size

        if self.groups == 1:
            Cov = (XY_sum - (X_sum.unsqueeze(1)) @ X_mean.unsqueeze(0)) / size  # check here.
            Id = torch.eye(Cov.shape[1], dtype=Cov.dtype, device=Cov.device)
            cov_isqrt = isqrt_newton_schulz_autograd(Cov + self.eps * Id, self.n_iter)
        else:
            Cov = (XY_sum - (X_sum.unsqueeze(2)) @ X_mean.unsqueeze(1)) / size  # check here.
            Id = torch.eye(self.num_features, dtype=Cov.dtype,
                           device=Cov.device).expand(self.groups,
                                                     self.num_features,
                                                     self.num_features)
            cov_isqrt = isqrt_newton_schulz_autograd_batch(Cov + self.eps * Id, self.n_iter)

        self.running_mean.mul_(1 - self.momentum)
        self.running_mean.add_(X_mean.detach() * self.momentum)
        self.running_cov_isqrt.mul_(1 - self.momentum)
        self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)

        return X_mean, cov_isqrt
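# Sketch of the whitening step behind _compute_mean_isqrt above. The repo's
# isqrt_newton_schulz_autograd helper is assumed to implement a coupled
# Newton-Schulz iteration along these lines; this standalone version only
# illustrates that (Cov + eps*I)^(-1/2) is what gets folded into the weights.
def _isqrt_newton_schulz_sketch(A, n_iter=15):
    """Approximate A^(-1/2) for a symmetric positive-definite matrix A."""
    import torch

    dim = A.shape[0]
    norm = A.norm()                             # scale A so the iteration converges
    Y = A / norm
    eye = torch.eye(dim, dtype=A.dtype, device=A.device)
    Z = torch.eye(dim, dtype=A.dtype, device=A.device)
    for _ in range(n_iter):
        T = 0.5 * (3.0 * eye - Z @ Y)
        Y = Y @ T                               # tends to (A/norm)^(1/2)
        Z = T @ Z                               # tends to (A/norm)^(-1/2)
    return Z / norm.sqrt()


def _check_isqrt_sketch():
    import torch

    X = torch.randn(1024, 16)
    Xc = X - X.mean(0)
    Cov = Xc.t() @ Xc / X.shape[0] + 1e-5 * torch.eye(16)
    W = _isqrt_newton_schulz_sketch(Cov)
    # W whitens Cov: W @ Cov @ W should be close to the identity.
    return torch.dist(W @ Cov @ W, torch.eye(16))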
class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
        super(_SynchronizedBatchNorm, self).__init__(
            num_features, eps=eps, momentum=momentum, affine=affine,
            track_running_stats=track_running_stats)

        if not self.track_running_stats:
            import warnings
            warnings.warn(
                'track_running_stats=False is not supported by the SynchronizedBatchNorm.'
            )

        self._sync_master = SyncMaster(self._data_parallel_master)
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input,
                self.running_mean if self.track_running_stats else None,
                self.running_var if self.track_running_stats else None,
                self.weight, self.bias,
                self.training if self.track_running_stats else True,
                self.momentum, self.eps)

        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(
                _ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(
                _ChildMessage(input_sum, input_ssum, sum_size))

        if self.affine:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(
                inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum.
        This method also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean \
                    + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var \
                    + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean \
                + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var \
                + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5
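# Sketch of the replication contract the classes in this file rely on. The real
# wrapper (e.g. DataParallelWithCallback in Synchronized-BatchNorm-PyTorch, or
# this repo's equivalent) presumably does something similar; this minimal
# version only shows how __data_parallel_replicate__ is expected to be called:
# once per replica, with a context object shared across replicas so that copy 0
# can hand its SyncMaster to the slave copies.
import torch.nn as nn


class _CallbackContext(object):
    """Empty container shared by corresponding sub-modules of all replicas."""
    pass


class _DataParallelWithCallbackSketch(nn.DataParallel):
    def replicate(self, module, device_ids):
        replicas = super(_DataParallelWithCallbackSketch, self).replicate(module, device_ids)
        # One shared context per sub-module position, reused across every replica.
        contexts = [_CallbackContext() for _ in replicas[0].modules()]
        for copy_id, replica in enumerate(replicas):
            for ctx, m in zip(contexts, replica.modules()):
                if hasattr(m, '__data_parallel_replicate__'):
                    m.__data_parallel_replicate__(ctx, copy_id)
        return replicas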