Example #1
    def apply(module: Module, name: str, n_power_iterations: int, dim: int,
              eps: float) -> 'SpectralNorm':
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(
            SpectralNormLoadStateDictPreHook(fn))
        return fn
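The `apply` classmethod above is what the user-facing `torch.nn.utils.spectral_norm` wrapper invokes. A minimal usage sketch (assuming a PyTorch build where `spectral_norm` is available under `torch.nn.utils`) showing the attributes the hook registers:

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Wrapping the layer renames the parameter to `weight_orig` and registers
# the `weight_u` / `weight_v` power-iteration buffers; `weight` itself
# becomes a plain attribute recomputed by the pre-forward hook.
layer = spectral_norm(nn.Linear(20, 40), name='weight', n_power_iterations=1)

print(hasattr(layer, 'weight_orig'))                            # True
print(hasattr(layer, 'weight_u'), hasattr(layer, 'weight_v'))   # True True

out = layer(torch.randn(8, 20))  # the hook normalizes the weight on every forward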
Example #2
 def __call__(self, state_dict, prefix, local_metadata, strict,
              missing_keys, unexpected_keys, error_msgs) -> None:
     fn = self.fn
     version = local_metadata.get('spectral_norm',
                                  {}).get(fn.name + '.version', None)
     if version is None or version < 1:
         weight_key = prefix + fn.name
         if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \
                 weight_key not in state_dict:
             # Detect if it is the updated state dict and just missing metadata.
             # This could happen if the users are crafting a state dict themselves,
             # so we just pretend that this is the newest.
             return
         has_missing_keys = False
         for suffix in ('_orig', '', '_u'):
             key = weight_key + suffix
             if key not in state_dict:
                 has_missing_keys = True
                 if strict:
                     missing_keys.append(key)
         if has_missing_keys:
             return
         with torch.no_grad():
             weight_orig = state_dict[weight_key + '_orig']
             weight = state_dict.pop(weight_key)
             sigma = (weight_orig / weight).mean()
             weight_mat = fn.reshape_weight_to_matrix(weight_orig)
             u = state_dict[weight_key + '_u']
             v = fn._solve_v_and_rescale(weight_mat, u, sigma)
             state_dict[weight_key + '_v'] = v
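The interesting step above is recovering `sigma`: the old state-dict format stored the already-normalized `weight` alongside `weight_orig`, and since `weight = weight_orig / sigma` with a scalar `sigma`, averaging the elementwise ratio recovers it exactly up to floating-point noise. A small illustrative check with made-up tensors:

import torch

# Stand-ins for the old-format state-dict entries (illustration only).
weight_orig = torch.randn(4, 3)
sigma = torch.tensor(2.5)
weight = weight_orig / sigma

recovered = (weight_orig / weight).mean()
print(torch.isclose(recovered, sigma))  # tensor(True)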
Example #3
def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
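A quick sanity check of the behavior documented above: every column receives `ceil(sparsity * rows)` zeroed entries (barring the measure-zero event that the normal draw itself produces an exact zero). Sketch:

import torch
import torch.nn as nn

w = torch.empty(100, 8)
nn.init.sparse_(w, sparsity=0.25, std=0.01)

# ceil(0.25 * 100) = 25 zeros per column, at random row positions.
print((w == 0).sum(dim=0))  # tensor([25, 25, 25, 25, 25, 25, 25, 25])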
Example #4
def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)
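The formula in the docstring is easy to verify empirically. A sketch, assuming the usual convention that `fan_in` of a 2-D tensor is its second dimension:

import math
import torch
import torch.nn as nn

w = torch.empty(256, 512)
nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')

# Expected std = gain / sqrt(fan_in), with gain = sqrt(2) for ReLU.
expected = nn.init.calculate_gain('relu') / math.sqrt(512)
print(expected, w.std().item())  # the two values should agree to within a percent or so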
Example #5
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
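This private helper backs the public `nn.init.trunc_normal_` in recent PyTorch releases. A usage sketch checking the clamped range, assuming a version that exposes that function:

import torch
import torch.nn as nn

w = torch.empty(10000)
nn.init.trunc_normal_(w, mean=0.0, std=1.0, a=-2.0, b=2.0)

# All values lie inside [a, b] thanks to the final clamp_.
print(w.min().item() >= -2.0, w.max().item() <= 2.0)  # True True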
Example #6
    def compute_weight(self, module: Module,
                       do_power_iteration: bool) -> torch.Tensor:
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in `out=` kwarg has correct shape, it will
        #          just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #    However, after we update `u` and `v` in-place, we need to **clone**
        #    them before using them to normalize the weight. This is to support
        #    backproping through two forward passes, e.g., the common pattern in
        #    GAN training: loss = D(real) - D(fake). Otherwise, engine will
        #    complain that variables needed to do backward for the first forward
        #    (i.e., the `u` and `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u),
                                  dim=0,
                                  eps=self.eps,
                                  out=v)
                    u = normalize(torch.mv(weight_mat, v),
                                  dim=0,
                                  eps=self.eps,
                                  out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone(memory_format=torch.contiguous_format)
                    v = v.clone(memory_format=torch.contiguous_format)

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight
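The loop above is plain power iteration. A standalone sketch (local names `W`, `u`, `v` are just for illustration) comparing the estimated `sigma` with the true largest singular value:

import torch
from torch.nn.functional import normalize

torch.manual_seed(0)
W = torch.randn(64, 32)

u = normalize(torch.randn(64), dim=0)
v = normalize(torch.randn(32), dim=0)
for _ in range(50):
    # Same update rule as compute_weight: v <- normalize(W^T u), u <- normalize(W v)
    v = normalize(torch.mv(W.t(), u), dim=0)
    u = normalize(torch.mv(W, v), dim=0)

sigma_est = torch.dot(u, torch.mv(W, v))
sigma_true = torch.linalg.svdvals(W)[0]  # assumes torch.linalg is available
print(sigma_est.item(), sigma_true.item())  # should agree closely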
Example #7
 def remove(self, module: Module) -> None:
     with torch.no_grad():
         weight = self.compute_weight(module, do_power_iteration=False)
     delattr(module, self.name)
     delattr(module, self.name + '_u')
     delattr(module, self.name + '_v')
     delattr(module, self.name + '_orig')
     module.register_parameter(self.name,
                               torch.nn.Parameter(weight.detach()))
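The user-facing counterpart is `torch.nn.utils.remove_spectral_norm`. A minimal sketch:

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm, remove_spectral_norm

layer = spectral_norm(nn.Linear(20, 40))
layer(torch.randn(8, 20))  # at least one forward pass so u and v are updated

# remove() re-registers `weight` as a plain Parameter holding the last
# normalized value and drops the hooks and the _orig/_u/_v entries.
remove_spectral_norm(layer)
print(isinstance(layer.weight, nn.Parameter), hasattr(layer, 'weight_orig'))  # True False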
Example #8
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
        :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # process_ids is a list of int identifying rank ids.
            >>> process_group = torch.distributed.new_group(process_ids)
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(module.num_features,
                                                   module.eps, module.momentum,
                                                   module.affine,
                                                   module.track_running_stats,
                                                   process_group)
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group))
        del module
        return module_output
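A non-distributed sketch of the conversion itself; training a `SyncBatchNorm` still requires `torch.distributed` and DDP, so only the module surgery is shown here:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
converted = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# The BatchNorm2d child is swapped out; affine parameters and running
# statistics are carried over to the new layer.
print([type(m).__name__ for m in converted])
# ['Conv2d', 'SyncBatchNorm', 'ReLU']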
Example #9
def dirac_(tensor, groups=1):
    r"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Dirac
    delta function. Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups > 1, each group of channels preserves identity.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (optional): number of groups in the conv layer (default: 1)
    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError(
            "Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()

    if sizes[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    out_chans_per_grp = sizes[0] // groups
    min_dim = min(out_chans_per_grp, sizes[1])

    with torch.no_grad():
        tensor.zero_()

        for g in range(groups):
            for d in range(min_dim):
                if dimensions == 3:  # Temporal convolution
                    tensor[g * out_chans_per_grp + d, d,
                           tensor.size(2) // 2] = 1
                elif dimensions == 4:  # Spatial convolution
                    tensor[g * out_chans_per_grp + d, d,
                           tensor.size(2) // 2,
                           tensor.size(3) // 2] = 1
                else:  # Volumetric convolution
                    tensor[g * out_chans_per_grp + d, d,
                           tensor.size(2) // 2,
                           tensor.size(3) // 2,
                           tensor.size(4) // 2] = 1
    return tensor
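As the docstring says, a Dirac-initialized square convolution with 'same' padding reproduces its input exactly, since each output channel picks up only the kernel-center value of the matching input channel. A quick check:

import torch
import torch.nn as nn

conv = nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False)
nn.init.dirac_(conv.weight)

x = torch.randn(1, 16, 10, 10)
with torch.no_grad():
    y = conv(x)
print(torch.allclose(y, x))  # True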
Example #10
    def flatten_parameters(self) -> None:
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this works only if the module is on the GPU and cuDNN is enabled.
        Otherwise, it's a no-op.
        """
        # Short-circuits if _flat_weights is only partially instantiated
        if len(self._flat_weights) != len(self._flat_weights_names):
            return

        for w in self._flat_weights:
            if not isinstance(w, Tensor):
                return
        # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
        # or the tensors in _flat_weights are of different dtypes

        first_fw = self._flat_weights[0]
        dtype = first_fw.dtype
        for fw in self._flat_weights:
            if (not isinstance(fw.data, Tensor) or not (fw.data.dtype == dtype)
                    or not fw.data.is_cuda
                    or not torch.backends.cudnn.is_acceptable(fw.data)):
                return

        # If any parameters alias, we fall back to the slower, copying code path. This is
        # a sufficient check, because overlapping parameter buffers that don't completely
        # alias would break the assumptions of the uniqueness check in
        # Module.named_parameters().
        unique_data_ptrs = set(p.data_ptr() for p in self._flat_weights)
        if len(unique_data_ptrs) != len(self._flat_weights):
            return

        with torch.cuda.device_of(first_fw):
            import torch.backends.cudnn.rnn as rnn

            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
            # an inplace operation on self._flat_weights
            with torch.no_grad():
                if torch._use_cudnn_rnn_flatten_weight():
                    torch._cudnn_rnn_flatten_weight(
                        self._flat_weights,
                        (4 if self.bias else 2), self.input_size,
                        rnn.get_cudnn_mode(self.mode), self.hidden_size,
                        self.num_layers, self.batch_first,
                        bool(self.bidirectional))
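RNN modules call this themselves (for example when constructed or moved to the GPU), but it can also be invoked explicitly, which is the usual fix for cuDNN's non-contiguous-weights warning after deep-copying or re-assigning parameters. A sketch; the call is a no-op on CPU:

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2).to(device)

# On GPU with cuDNN this re-coalesces the per-layer weights into one
# contiguous blob so the fast cuDNN kernels can be used; otherwise no-op.
lstm.flatten_parameters()

out, (h, c) = lstm(torch.randn(5, 3, 10, device=device))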
Example #11
def eye_(tensor):
    r"""Fills the 2-dimensional input `Tensor` with the identity
    matrix. Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    with torch.no_grad():
        torch.eye(*tensor.shape,
                  out=tensor,
                  requires_grad=tensor.requires_grad)
    return tensor
Example #12
def orthogonal_(tensor, gain=1):
    r"""Fills the input `Tensor` with a (semi) orthogonal matrix, as
    described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError(
            "Only tensors with 2 or more dimensions are supported")

    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new(rows, cols).normal_(0, 1)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
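A quick check of the (semi-)orthogonality: when rows >= cols the resulting columns are orthonormal, so `w^T w` is the identity. Sketch:

import torch
import torch.nn as nn

w = torch.empty(8, 3)
nn.init.orthogonal_(w)

print(torch.allclose(w.t() @ w, torch.eye(3), atol=1e-6))  # True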
Example #13
    def forward(ctx, I1, I2, nbin=100):
        with tp.no_grad():
            if hasattr(ctx, 'JH'): del ctx.JH
            nbin = tp.tensor(nbin)
            data_pair = tp.stack(I1.flatten(1), I2.flatten(1), dim={1})
            nbatch, nhist, ndata = data_pair.ishape
            indices = []
            values = []
            ctx.window = (tp.image_grid(4, 4) - 1).flatten(1).transpose(0, 1)
            for shift in ctx.window:
                # [nbatch] x {nhist} x ndata
                hist_pos = data_pair * nbin
                index = tp.clamp(
                    tp.floor(hist_pos).long() + shift, 0, nbin - 1)
                batch_idx = tp.arange(nbatch).expand_to([nbatch], {1}, ndata)
                index = tp.cat(batch_idx, index, 1)
                value = Bspline(shift.expand_to(data_pair),
                                tp.decimal(hist_pos)).prod(1)
                indices.append(index)
                values.append(value)
            # n_batch x (1 + n_hist) x (n_data x 4 ** n_hist)
            Mindices = tp.cat(indices, -1)
            # n_batch x (n_data x 4 ** n_hist)
            Mvalues = tp.cat(values, -1)
            # (1 + n_hist) x (n_batch x n_data x 4 ** n_hist)
            indices = Mindices.transpose(0, 1).flatten(1)
            # (n_batch x n_data x 4 ** n_hist)
            values = Mvalues.flatten(0)
            if tp.Device == tp.DeviceCPU: creator = torch.sparse.FloatTensor
            else: creator = torch.cuda.sparse.FloatTensor
            collected = creator(indices, values,
                                (nbatch, nbin, nbin)).to_dense()
            collected = tp.Tensor(collected, batch_dim=0)

            ctx.nbin = nbin
            ctx.Ishape = I1.shape
            ctx.data_pair = data_pair
            ctx.JH = collected / ndata
        return ctx.JH
Example #14
 def backward(ctx, grad_output):
     with tp.no_grad():
         nbin = ctx.nbin
         data_pair = ctx.data_pair
         nbatch, nhist, ndata = data_pair.ishape
         dPdI1 = tp.zeros(ctx.Ishape)
         dPdI2 = tp.zeros(ctx.Ishape)
         for shift in ctx.window:
             # [nbatch] x {nhist} x ndata
             shift = shift.view(1, 2, 1)
             hist_pos = data_pair * nbin
             index = torch.clamp(
                 torch.floor(hist_pos).long() + shift, 0, nbin - 1)
             grad_y = grad_output[(slice(None), ) +
                                  index.split(1, 1)].squeeze(2)
             value = grad_y.gather(
                 0,
                 tp.arange(nbatch).long().unsqueeze(0).unsqueeze(-1).repeat(
                     1, 1, ndata)).view(ctx.Ishape)
             dPdI1 += value * dBspline_WRT_I1(
                 shift, tp.decimal(data_pair * nbin)).view(ctx.Ishape)
             dPdI2 += value * dBspline_WRT_I2(
                 shift, tp.decimal(data_pair * nbin)).view(ctx.Ishape)
     return dPdI1, dPdI2, None
Example #15
def _no_grad_normal_(tensor, mean, std):
    with torch.no_grad():
        return tensor.normal_(mean, std)
Example #16
def _no_grad_fill_(tensor, val):
    with torch.no_grad():
        return tensor.fill_(val)
Example #17
def _no_grad_zero_(tensor):
    with torch.no_grad():
        return tensor.zero_()
Example #18
def _no_grad_uniform_(tensor, a, b):
    with torch.no_grad():
        return tensor.uniform_(a, b)
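These one-line wrappers exist so the public init functions can mutate a tensor in-place without autograd complaining about an in-place operation on a leaf that requires grad, and without recording the op in the graph. The same pattern works for any custom initializer; a hypothetical example (the name `scaled_sign_` is made up for illustration):

import torch

def scaled_sign_(tensor, scale=0.1):
    # Hypothetical custom initializer following the same pattern as the
    # _no_grad_* helpers above: mutate in-place under torch.no_grad().
    with torch.no_grad():
        return tensor.sign_().mul_(scale)

p = torch.nn.Parameter(torch.randn(3, 5))
scaled_sign_(p)  # would raise without no_grad(), since p is a leaf requiring grad
print(p.requires_grad, p.grad_fn)  # True None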