Example #1
class SynchronizedDelinear(Delinear):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 eps=1e-5,
                 n_iter=5,
                 momentum=0.1,
                 block=512):
        assert ReduceAddCoalesced is not None, 'Cannot use SynchronizedDelinear without CUDA support.'

        super(SynchronizedDelinear,
              self).__init__(in_features, out_features, bias, eps, n_iter,
                             momentum, block)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use the original implementation.
        if not (self._is_parallel and self.training):
            return super().forward(input)

        X = input.view(-1, self.block)

        X_sum = X.sum(0)
        XY_sum = X.t() @ X
        sum_size = X.shape[0]

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            X_mean, cov_isqrt = self._sync_master.run_master(
                _ChildMessage(X_sum, XY_sum, sum_size))
        else:
            X_mean, cov_isqrt = self._slave_pipe.run_slave(
                _ChildMessage(X_sum, XY_sum, sum_size))

        w = self.weight.view(-1, self.block) @ cov_isqrt

        if self.bias is None:
            b = -(w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0],
                                                  -1).sum(1)
        else:
            b = self.bias - (w @ (X_mean.unsqueeze(1))).view(
                self.weight.shape[0], -1).sum(1)

        w = w.view(self.weight.shape)
        return F.linear(input, w, b)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates,
                               key=lambda i: i[1].X_sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].X_sum.get_device() for i in intermediates]

        total_sum = sum([i[1].sum_size for i in intermediates])
        X_sum, XY_sum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        X_mean, cov_isqrt = self._compute_mean_isqrt(X_sum, XY_sum, total_sum)

        broadcasted = Broadcast.apply(target_gpus, X_mean, cov_isqrt)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append(
                (rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_isqrt(self, X_sum, XY_sum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        X_mean = X_sum / size

        Cov = (XY_sum -
               (X_sum.unsqueeze(1)) @ X_mean.unsqueeze(0)) / size  #check here.

        Id = torch.eye(Cov.shape[1], dtype=Cov.dtype, device=Cov.device)

        cov_isqrt = isqrt_newton_schulz_autograd(Cov + self.eps * Id,
                                                 self.n_iter)

        self.running_mean.mul_(1 - self.momentum)
        self.running_mean.add_(X_mean.detach() * self.momentum)
        self.running_cov_isqrt.mul_(1 - self.momentum)
        self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)
        return X_mean, cov_isqrt
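
The forward pass above folds the centering and whitening into a single F.linear call: the weight is replaced by weight @ cov_isqrt and the bias is shifted by w @ X_mean. A small stand-alone check of that identity (shapes are made up, and for simplicity block == in_features so the whole input is whitened as one block):

# Identity used in forward():
#   F.linear(x, W @ C, bias - (W @ C) @ mean) == F.linear((x - mean) @ C, W, bias)
# with C = cov_isqrt (symmetric). Shapes are illustrative only.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, in_f, out_f = 5, 8, 3
x = torch.randn(B, in_f)
W = torch.randn(out_f, in_f)
bias = torch.randn(out_f)
mean = torch.randn(in_f)
A = torch.randn(in_f, in_f)
C = A @ A.t()                       # stand-in for the symmetric cov_isqrt

w = W @ C                           # fold the whitening into the weight
b = bias - w @ mean                 # fold the centering into the bias
assert torch.allclose(F.linear(x, w, b),
                      F.linear((x - mean) @ C, W, bias),
                      atol=1e-4)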
Example #2
class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        assert ReduceAddCoalesced is not None, 'Cannot use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features,
                                                     eps=eps,
                                                     momentum=momentum,
                                                     affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(input, self.running_mean, self.running_var,
                                self.weight, self.bias, self.training,
                                self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input**2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(
                _ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(
                _ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(
                inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates,
                               key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append(
                (rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (
                    1 - self.momentum
                ) * self.running_mean + self.momentum * mean.data
                self.running_var = (
                    1 - self.momentum
                ) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (
                1 -
                self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (
                1 - self.momentum
            ) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps)**-0.5
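
The master-side reduction works because sums, square-sums, and element counts are additive across devices. A self-contained sketch of that pooling (plain PyTorch, no GPUs required; the two chunks stand in for the per-device inputs and the shapes are illustrative):

# Pooling per-chunk (sum, ssum, size) reproduces the full-batch mean and inv_std
# exactly as _compute_mean_std does.
import torch

eps = 1e-5
x = torch.randn(8, 4, 16)                       # (B, C, spatial), as after the view in forward
chunks = x.chunk(2, dim=0)                      # stand-ins for the per-GPU slices

sum_ = sum(c.sum(dim=(0, 2)) for c in chunks)   # what ReduceAddCoalesced accumulates
ssum = sum((c ** 2).sum(dim=(0, 2)) for c in chunks)
size = sum(c.size(0) * c.size(2) for c in chunks)

mean = sum_ / size
bias_var = (ssum - sum_ * mean) / size
inv_std = bias_var.clamp(eps) ** -0.5

assert torch.allclose(mean, x.mean(dim=(0, 2)), atol=1e-5)
assert torch.allclose(inv_std,
                      x.var(dim=(0, 2), unbiased=False).clamp(eps) ** -0.5,
                      atol=1e-4)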
Example #3
class SynchronizedDeconvTransposed(FastDeconvTransposed):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 groups=1,
                 bias=True,
                 dilation=1,
                 eps=1e-5,
                 n_iter=5,
                 momentum=0.1,
                 block=64,
                 sampling_stride=3):
        assert ReduceAddCoalesced is not None, 'Cannot use SynchronizedDeconvTransposed without CUDA support.'

        super(SynchronizedDeconvTransposed,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, output_padding, groups, bias, dilation,
                             eps, n_iter, momentum, block, sampling_stride)

        if not self.track_running_stats:
            import warnings
            warnings.warn(
                'track_running_stats=False is not supported by the SynchronizedDeconvTransposed.'
            )

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, x, output_size=None):
        # If it is not parallel computation or is in evaluation mode, use the original implementation.
        if not (self._is_parallel and self.training):
            return super().forward(x, output_size)

        N, C, H, W = x.shape
        B = self.block

        if self.training and self.track_running_stats:
            self.counter += 1

        # im2col: N x cols x pixels -> N*pixels x cols
        if self.kernel_size[0] > 1:
            # The adjoint of a conv operation is a full correlation operation, so pad first.
            padding = (self.padding[0] + self.kernel_size[0] - 1,
                       self.padding[1] + self.kernel_size[1] - 1)
            X = torch.nn.functional.unfold(x, self.kernel_size, self.dilation,
                                           padding,
                                           self.sampling_stride).transpose(
                                               1, 2).contiguous()
        else:
            # channel-wise sampling
            X = x.permute(0, 2, 3,
                          1).contiguous().view(-1,
                                               C)[::self.sampling_stride**2, :]

        if self.groups == 1:
            # (C//B*N*pixels,k*k*B)
            X = X.view(-1, self.num_features,
                       C // B).transpose(1, 2).contiguous().view(
                           -1, self.num_features)
        else:
            X = X.view(-1, X.shape[-1])

        if self.groups == 1:
            X_sum = X.sum(0)
            XY_sum = X.t() @ X
            sum_size = X.shape[0]
        else:
            # Arrange the samples as (groups, N*pixels, num_features) so that the
            # per-group statistics match the grouped branch of _compute_mean_isqrt.
            X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
            X_sum = X.sum(1)
            XY_sum = X.transpose(1, 2) @ X
            sum_size = X.shape[1]

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            X_mean, cov_isqrt = self._sync_master.run_master(
                _ChildMessage(X_sum, XY_sum, sum_size))
        else:
            X_mean, cov_isqrt = self._slave_pipe.run_slave(
                _ChildMessage(X_sum, XY_sum, sum_size))

        # X * deconv * conv = X * (deconv * conv)
        weight = torch.flip(self.weight, [2, 3])

        if self.groups == 1:
            w = weight.view(-1, self.num_features, C // B).transpose(
                1, 2).contiguous().view(-1, self.num_features) @ cov_isqrt
            b = self.bias - (w @ (X_mean.unsqueeze(1))).view(
                self.weight.shape[0], -1).sum(1)
            w = w.view(-1, C // B,
                       self.num_features).transpose(1, 2).contiguous()
        else:
            w = self.weight.view(C // B, -1, self.num_features) @ cov_isqrt
            b = self.bias - (w @ (X_mean.view(-1, self.num_features, 1))).view(
                self.bias.shape)

        w = torch.flip(w.view(weight.shape), [2, 3])

        output_padding = self._output_padding(x, output_size, self.stride,
                                              self.padding, self.kernel_size)
        x = F.conv_transpose2d(x, w, b, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)

        return x

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates,
                               key=lambda i: i[1].X_sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].X_sum.get_device() for i in intermediates]

        total_sum = sum([i[1].sum_size for i in intermediates])
        X_sum, XY_sum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        X_mean, cov_isqrt = self._compute_mean_isqrt(X_sum, XY_sum, total_sum)

        broadcasted = Broadcast.apply(target_gpus, X_mean, cov_isqrt)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append(
                (rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_isqrt(self, X_sum, XY_sum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        X_mean = X_sum / size

        if self.groups == 1:
            Cov = (XY_sum - (X_sum.unsqueeze(1)) @ X_mean.unsqueeze(0)
                   ) / size  #check here.
            Id = torch.eye(Cov.shape[1], dtype=Cov.dtype, device=Cov.device)
            cov_isqrt = isqrt_newton_schulz_autograd(Cov + self.eps * Id,
                                                     self.n_iter)
        else:
            Cov = (XY_sum - (X_sum.unsqueeze(2)) @ X_mean.unsqueeze(1)
                   ) / size  #check here.
            Id = torch.eye(self.num_features,
                           dtype=Cov.dtype,
                           device=Cov.device).expand(self.groups,
                                                     self.num_features,
                                                     self.num_features)
            cov_isqrt = isqrt_newton_schulz_autograd_batch(
                Cov + self.eps * Id, self.n_iter)

        self.running_mean.mul_(1 - self.momentum)
        self.running_mean.add_(X_mean.detach() * self.momentum)
        self.running_cov_isqrt.mul_(1 - self.momentum)
        self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)

        return X_mean, cov_isqrt
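
A minimal usage sketch for the layer above. DataParallelWithCallback is assumed to be the sync_batchnorm-style wrapper that calls __data_parallel_replicate__(ctx, copy_id) on every replica after replication; it is not defined in this example, and the layer arguments are made up:

# Usage sketch (assumptions: two visible GPUs and a DataParallelWithCallback wrapper
# that triggers __data_parallel_replicate__ on each replica).
import torch

layer = SynchronizedDeconvTransposed(16, 8, kernel_size=3, stride=2,
                                     padding=1, output_padding=1, block=16)
model = DataParallelWithCallback(layer.cuda(), device_ids=[0, 1])

model.train()
y = model(torch.randn(8, 16, 32, 32).cuda())   # X_mean / cov_isqrt are synced across GPUs

model.eval()                                   # falls back to super().forward with running stats
y = model(torch.randn(2, 16, 32, 32).cuda())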
Example #4
class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        assert ReduceAddCoalesced is not None, 'Cannot use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm,
              self).__init__(num_features,
                             eps=eps,
                             momentum=momentum,
                             affine=affine,
                             track_running_stats=track_running_stats)

        if not self.track_running_stats:
            import warnings
            warnings.warn(
                'track_running_stats=False is not supported by the SynchronizedBatchNorm.'
            )

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean if self.track_running_stats else None,
                self.running_var if self.track_running_stats else None,
                self.weight, self.bias,
                self.training if self.track_running_stats else True,
                self.momentum, self.eps)

        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input**2)

        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(
                _ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(
                _ChildMessage(input_sum, input_ssum, sum_size))

        if self.affine:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(
                inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        intermediates = sorted(intermediates,
                               key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append(
                (rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (
                    1 - self.momentum
                ) * self.running_mean + self.momentum * mean.data
                self.running_var = (
                    1 - self.momentum
                ) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (
                1 -
                self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (
                1 - self.momentum
            ) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps)**-0.5
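
Existing models usually ship with plain nn.BatchNorm2d layers, so a common pattern is to swap them for the synchronized variant before wrapping the model for multi-GPU training. A hedged conversion sketch; SynchronizedBatchNorm2d is assumed to be a concrete subclass of the _SynchronizedBatchNorm above and is not shown here:

# Recursively replace nn.BatchNorm2d with the (assumed) SynchronizedBatchNorm2d subclass,
# copying affine parameters and running statistics.
import torch.nn as nn

def convert_sync_batchnorm(module):
    if isinstance(module, nn.BatchNorm2d):
        sync_bn = SynchronizedBatchNorm2d(
            module.num_features,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats)
        if module.affine:
            sync_bn.weight.data.copy_(module.weight.data)
            sync_bn.bias.data.copy_(module.bias.data)
        if module.track_running_stats:
            sync_bn.running_mean.copy_(module.running_mean)
            sync_bn.running_var.copy_(module.running_var)
        return sync_bn
    for name, child in module.named_children():
        setattr(module, name, convert_sync_batchnorm(child))
    return module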