Exemplo n.º 1
0
class SpectralNorm2d(nn.Module):
    r"""2D Spectral Norm Module as described in `"Spectral Normalization
    for Generative Adversarial Networks" by Miyato et al. <https://arxiv.org/abs/1802.05957>`_
    The spectral norm is computed using ``power iterations``.

    The wrapped weight is viewed as a matrix :math:`W` of shape ``(weight.shape[0], -1)``.

    Computation Steps (one power iteration, repeated ``power_iterations`` times):

    .. math:: v_{t + 1} = \frac{W^T u_t}{||W^T u_t||}
    .. math:: u_{t + 1} = \frac{W v_{t + 1}}{||W v_{t + 1}||}
    .. math:: Norm(W) = u^T W v
    .. math:: Output = \frac{W}{Norm(W)} = \frac{W}{u^T W v}

    Repeated, these updates converge towards the leading singular vectors of
    :math:`W`, so :math:`u^T W v` approaches its largest singular value.

    Args:
        module (torch.nn.Module): The Module on which the Spectral Normalization needs to be
            applied.
        name (str, optional): The attribute of the ``module`` on which normalization needs to
            be performed. Default: ``'weight'``.
        power_iterations (int, optional): Number of power iterations run per forward pass.
            ``1`` is usually enough given the weights vary quite gradually. Default: ``1``.

    Example:
        .. code:: python

            >>> layer = SpectralNorm2d(Conv2d(3, 16, 1))
            >>> x = torch.rand(1, 3, 10, 10)
            >>> layer(x)
    """
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm2d, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]
        # ``u``/``v`` hold the running singular-vector estimates; they are
        # updated manually in ``forward`` and never receive gradients.
        self.u = Parameter(w.data.new(height).normal_(0, 1),
                           requires_grad=False)
        self.v = Parameter(w.data.new(width).normal_(0, 1),
                           requires_grad=False)
        self.u.data = self._l2normalize(self.u.data)
        self.v.data = self._l2normalize(self.v.data)
        # The un-normalized weight becomes the trainable parameter ``w_bar``.
        # The original parameter is removed from ``module`` so that ``forward``
        # can install the normalized tensor under the same attribute name.
        self.w_bar = Parameter(w.data)
        del self.module._parameters[self.name]

    def _l2normalize(self, x, eps=1e-12):
        r"""Function to calculate the ``L2 Normalized`` form of a Tensor

        Args:
            x (torch.Tensor): Tensor which needs to be normalized.
            eps (float, optional): A small value added to the norm to avoid division by zero.

        Returns:
            Normalized form of the tensor ``x``.
        """
        return x / (torch.norm(x) + eps)

    def forward(self, *args, **kwargs):
        r"""Applies spectral normalization to the ``name`` attribute of the ``module`` and
        computes the output of the ``module``.

        Each call runs ``power_iterations`` updates of ``u`` and ``v``, sets
        ``module.<name>`` to ``w_bar / sigma`` and then calls the module.

        Returns:
            The output of the ``module``.
        """
        height = self.w_bar.data.shape[0]
        # The power iterations only refine the non-trainable estimates ``u``
        # and ``v``; running them without autograd avoids building a graph
        # that would be discarded anyway. The resulting values are unchanged.
        with torch.no_grad():
            w_mat = self.w_bar.view(height, -1)
            for _ in range(self.power_iterations):
                self.v.data = self._l2normalize(torch.mv(torch.t(w_mat), self.u))
                self.u.data = self._l2normalize(torch.mv(w_mat, self.v))
        # sigma must be computed with autograd enabled so gradients reach w_bar.
        sigma = self.u.dot(self.w_bar.view(height, -1).mv(self.v))
        setattr(self.module, self.name,
                self.w_bar / sigma.expand_as(self.w_bar))
        # Invoke the module itself (not ``.forward``) so registered hooks run.
        return self.module(*args, **kwargs)
Exemplo n.º 2
0
class RC_CP_MiniMax(nn.Module):
    """
    Resource-Constrained channel pruning minimax:
    min_{w, s} max_{y>=0, z>=0} L(w) + <y, |w|_{ceil(s), 2}^2> + z*(resource(s) - B)

    Learnable state (all ``Parameter``s):

    * ``s`` -- one entry per group in ``bncp_layers``, initialised to 0.
    * ``y`` -- one multiplier per group, initialised to ``y_init``.
    * ``z`` -- scalar multiplier of the resource constraint, initialised to ``z_init``.

    Args:
        net_model: The network being pruned (stored as ``self.net``).
        resource_fn: Callable applied to the ceiled ``s`` vector; its result is
            compared against a budget in ``zloss`` / ``sloss2``.
        bncp_layers: Sequence of indexable layer groups. All layers in a group
            must have the same ``weight.numel()``; ``layers[0].in_features`` is
            used to compute the per-group upper bound ``s_ub``.
        bncp_layers_dict: Mapping from layer to an integer flag. A flag of ``0``
            sets ``s_ub = in_features // group_size``; any other value sets
            ``s_ub = in_features``.
        group_size: Divisor used for flag-0 groups and passed to the scoring helper.
        z_init: Initial value of ``z``.
        y_init: Initial value of every entry of ``y``.
        punit: Pruning unit (must be >= 1). When greater than 1, every entry of
            ``s`` except the first is ceiled to a multiple of ``punit``.
    """
    def __init__(self,
                 net_model,
                 resource_fn,
                 bncp_layers,
                 bncp_layers_dict,
                 group_size,
                 z_init=0.0,
                 y_init=0.0,
                 punit=1):
        super(RC_CP_MiniMax, self).__init__()
        self.net = net_model
        self.bncp_layers = bncp_layers
        self.bncp_layers_dict = bncp_layers_dict
        self.group_size = group_size
        n_layers = len(self.bncp_layers)
        self.s = Parameter(torch.zeros(n_layers))
        self.y = Parameter(torch.Tensor(n_layers))
        self.y.data.fill_(y_init)
        self.z = Parameter(torch.tensor(float(z_init)))
        self.resource_fn = resource_fn
        # Scratch tensor filled in place and returned by ``get_least_s_norm``.
        self.__least_s_norm = torch.zeros_like(self.s.data)
        # Per-group upper bound on ``s``.
        self.s_ub = torch.zeros_like(self.s.data)
        assert punit >= 1
        self.punit = float(punit)
        for i, layers in enumerate(self.bncp_layers):
            first = layers[0]
            for layer in layers:
                assert layer.weight.numel() == first.weight.numel()
            # The flag is read from the *last* layer of the group; this keeps the
            # historical behaviour, where the inner loop variable leaked out of
            # the loop. NOTE(review): confirm whether layers[0] was intended.
            if self.bncp_layers_dict[layers[-1]] == 0:
                self.s_ub[i] = first.in_features // self.group_size
            else:
                self.s_ub[i] = first.in_features

    @staticmethod
    def _optimizer_eps(optimizer):
        """Return the ``eps`` of the first param group for Adam optimizers, else ``None``."""
        if isinstance(optimizer, torch.optim.Adam):
            return optimizer.param_groups[0]['eps']
        return None

    def ceiled_s(self):
        """Return ``s`` passed through ``ste_ceil``.

        When ``punit > 1`` the first entry is ceiled on its own, and the remaining
        entries are ceiled in units of ``punit`` (``ceil(s / punit) * punit``).
        """
        if self.punit > 1.0:
            head = ste_ceil(self.s[0].view(-1))
            tail = ste_ceil(self.s[1:] / self.punit) * self.punit
            return torch.cat((head, tail))
        return ste_ceil(self.s)

    def zloss(self, budget):
        """``z * (resource(ceil(s)) - budget)``; ``s`` is detached via ``.data``."""
        return self.z * (self.resource_fn(self.ceiled_s().data) - budget)

    def sloss1(self, optimizer):
        """``<y, least_s_sum(ceil(s)_i, scores_i)>``, where ``y`` is detached and
        the per-group scores are detached and moved to CPU.

        Args:
            optimizer: Weight optimizer; its ``eps`` is forwarded to the scoring
                helper when it is ``torch.optim.Adam``.
        """
        eps = self._optimizer_eps(optimizer)
        s = self.ceiled_s()

        parts = []
        for i, layers in enumerate(self.bncp_layers):
            scores = weight_list_to_scores(layers,
                                           self.bncp_layers_dict,
                                           self.group_size,
                                           eps=eps).data.cpu()
            parts.append(least_s_sum(s[i], scores).view(-1))

        # ``torch.cat`` rejects an empty list; no groups means an empty vector.
        w_s_norm = torch.cat(parts) if parts else torch.empty(0)
        # Match ``y``'s dtype/device so the dot product works after ``.to(...)``.
        w_s_norm = w_s_norm.to(dtype=self.y.dtype, device=self.y.device)
        return self.y.data.dot(w_s_norm)

    def sloss2(self, budget):
        """``resource(ceil(s)) - budget`` with gradients flowing into ``s``."""
        s = self.ceiled_s()
        rc = self.resource_fn(s)
        return rc - budget

    def get_least_s_norm(self, optimizer):
        """For each group, sum the ``ceil(s_i)`` smallest scores.

        Returns:
            The internal scratch tensor (the same object on every call), updated
            in place with one plain Python float per group.
        """
        eps = self._optimizer_eps(optimizer)
        res = self.__least_s_norm
        s = self.ceiled_s()
        for i, layers in enumerate(self.bncp_layers):
            scores = weight_list_to_scores(layers,
                                           self.bncp_layers_dict,
                                           self.group_size,
                                           eps=eps)
            k = int(s[i].ceil().item())
            res[i] = torch.topk(scores, k, largest=False,
                                sorted=False)[0].sum().item()
        return res

    def yloss(self, w_optimizer):
        """``<y, least_s_norm>``; gradients flow into ``y`` only."""
        least = self.get_least_s_norm(w_optimizer)
        # The scratch tensor is not a registered buffer, so it stays on its
        # creation device/dtype; align it with ``y`` before the dot product.
        return self.y.dot(least.to(dtype=self.y.dtype, device=self.y.device))