class SpectralNorm2d(nn.Module):
    r"""2D Spectral Norm Module as described in `"Spectral Normalization
    for Generative Adversarial Networks by Miyato et. al." <https://arxiv.org/abs/1802.05957>`_

    The spectral norm is estimated using ``power iterations`` on the weight
    reshaped to a 2D matrix :math:`W` of shape ``(weight.shape[0], -1)``.

    Computation Steps (repeated ``power_iterations`` times per forward call):

    .. math:: v_{t + 1} = \frac{W^T u_t}{||W^T u_t||}
    .. math:: u_{t + 1} = \frac{W v_{t + 1}}{||W v_{t + 1}||}

    followed by

    .. math:: Norm(W) = u^T W v
    .. math:: Output = \frac{W}{Norm(W)} = \frac{W}{u^T W v}

    The original parameter is removed from ``module`` and kept on this
    wrapper as ``w_bar``; the normalized tensor is written back onto
    ``module`` under ``name`` on every forward call, so the attribute does not
    exist on ``module`` between construction and the first forward call.

    Args:
        module (torch.nn.Module): The Module on which the Spectral Normalization needs to be
            applied.
        name (str, optional): The attribute of the ``module`` on which normalization needs to
            be performed.
        power_iterations (int, optional): Total number of iterations for the norm to converge.
            ``1`` is usually enough given the weights vary quite gradually.

    Example:
        .. code:: python

            >>> layer = SpectralNorm2d(Conv2d(3, 16, 1))
            >>> x = torch.rand(1, 3, 10, 10)
            >>> layer(x)
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm2d, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        w = getattr(self.module, self.name)
        # The weight is treated as a (height, width) matrix: first dimension
        # kept, all remaining dimensions flattened.
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]
        # u and v are the running singular-vector estimates. They are stored
        # as non-trainable Parameters and are updated in place in forward().
        self.u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        self.v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        self.u.data = self._l2normalize(self.u.data)
        self.v.data = self._l2normalize(self.v.data)
        # The trainable weight now lives on the wrapper; the wrapped module's
        # own parameter entry is dropped so it is not registered twice.
        self.w_bar = Parameter(w.data)
        del self.module._parameters[self.name]

    def _l2normalize(self, x, eps=1e-12):
        r"""Function to calculate the ``L2 Normalized`` form of a Tensor

        Args:
            x (torch.Tensor): Tensor which needs to be normalized.
            eps (float, optional): A small value needed to avoid infinite values.

        Returns:
            Normalized form of the tensor ``x``.
        """
        return x / (torch.norm(x) + eps)

    def forward(self, *args, **kwargs):
        r"""Computes the output of the ``module`` and applies spectral normalization to the
        ``name`` attribute of the ``module``.

        Args:
            *args: Positional arguments passed through to ``module.forward``.
            **kwargs: Keyword arguments passed through to ``module.forward``.

        Returns:
            The output of the ``module``.
        """
        height = self.w_bar.data.shape[0]
        # Loop-invariant 2D view of the weight, shared by the power iteration
        # and the sigma computation below.
        w_mat = self.w_bar.view(height, -1)
        # The updates only refresh the u / v estimates through ``.data``, so
        # there is no need to record them in the autograd graph.
        with torch.no_grad():
            for _ in range(self.power_iterations):
                self.v.data = self._l2normalize(torch.mv(torch.t(w_mat), self.u))
                self.u.data = self._l2normalize(torch.mv(w_mat, self.v))
        # sigma is computed with autograd enabled so gradients reach w_bar.
        sigma = self.u.dot(w_mat.mv(self.v))
        setattr(self.module, self.name, self.w_bar / sigma.expand_as(self.w_bar))
        return self.module.forward(*args, **kwargs)
class RC_CP_MiniMax(nn.Module):
    """Minimax formulation of resource-constrained channel pruning.

    Objective::

        min_{w, s} max_{y>=0, z>=0}
            L(w) + <y, |w|_{ceil(s), 2}^2> + z*(resource(s) - B)

    The module holds ``s`` and ``y`` (one entry per layer group in
    ``bncp_layers``) and the scalar ``z`` as parameters, and exposes the
    individual loss terms as methods.

    Args:
        net_model: Module stored as ``self.net``.
        resource_fn: Callable applied to the ceiled ``s`` tensor; its result
            is compared against a budget.
        bncp_layers: Sequence of layer groups. Every layer in a group must
            have the same number of weight elements as the group's first
            layer.
        bncp_layers_dict: Mapping from layer to a tag. A tag of ``0`` makes
            the group's ``s_ub`` entry ``in_features // group_size``; any
            other tag makes it ``in_features``.
        group_size: Divisor used for tag-``0`` groups; also forwarded to
            ``weight_list_to_scores``.
        z_init: Initial value of ``z``.
        y_init: Initial value of every entry of ``y``.
        punit: Rounding unit (must be >= 1) applied to all but the first
            entry of ``s`` in :meth:`ceiled_s`.
    """

    def __init__(self, net_model, resource_fn, bncp_layers, bncp_layers_dict,
                 group_size, z_init=0.0, y_init=0.0, punit=1):
        super(RC_CP_MiniMax, self).__init__()
        self.net = net_model
        self.bncp_layers = bncp_layers
        self.bncp_layers_dict = bncp_layers_dict
        self.group_size = group_size

        group_count = len(self.bncp_layers)
        # Parameters are registered in the order s, y, z.
        self.s = Parameter(torch.zeros(group_count))
        y_values = torch.Tensor(group_count)
        y_values.fill_(y_init)
        self.y = Parameter(y_values)
        self.z = Parameter(torch.tensor(float(z_init)))

        self.resource_fn = resource_fn
        # Scratch buffer that get_least_s_norm() overwrites and returns.
        self.__least_s_norm = torch.zeros_like(self.s.data)
        self.s_ub = torch.zeros_like(self.s.data)

        assert punit >= 1
        self.punit = float(punit)

        for idx, group in enumerate(self.bncp_layers):
            for member in group:
                assert member.weight.numel() == group[0].weight.numel()
                # Every member overwrites the entry, so the tag of the last
                # member of the group decides the final value.
                full_width = group[0].in_features
                if self.bncp_layers_dict[member] == 0:
                    self.s_ub[idx] = full_width // self.group_size
                else:
                    self.s_ub[idx] = full_width

    def ceiled_s(self):
        """Return ``s`` passed through ``ste_ceil``.

        When ``punit`` is greater than one, the first entry is ceiled as is
        and the remaining entries are ceiled in units of ``punit``.
        """
        if not self.punit > 1.0:
            return ste_ceil(self.s)
        first = ste_ceil(self.s[0].view(-1))
        rest = ste_ceil(self.s[1:] / self.punit) * self.punit
        return torch.cat((first, rest))

    def zloss(self, budget):
        """Return ``z * (resource_fn(ceiled_s().data) - budget)``."""
        usage = self.resource_fn(self.ceiled_s().data)
        return self.z * (usage - budget)

    def sloss1(self, optimizer):
        """Return the dot product of ``y.data`` with the per-group
        ``least_s_sum`` values.

        ``eps`` is taken from the optimizer's first param group when the
        optimizer is ``torch.optim.Adam``; otherwise ``None`` is passed on.
        """
        eps = (optimizer.param_groups[0]['eps']
               if isinstance(optimizer, torch.optim.Adam) else None)
        budgets = self.ceiled_s()
        # The uninitialised leading element keeps the concatenation valid
        # for an empty group list; it is sliced away before use.
        pieces = [torch.empty(1)]
        for idx, group in enumerate(self.bncp_layers):
            scores = weight_list_to_scores(
                group, self.bncp_layers_dict, self.group_size, eps=eps)
            pieces.append(
                least_s_sum(budgets[idx], scores.data.cpu()).view(-1))
        group_norms = torch.cat(pieces)[1:]
        return self.y.data.dot(group_norms)

    def sloss2(self, budget):
        """Return ``resource_fn(ceiled_s()) - budget``."""
        return self.resource_fn(self.ceiled_s()) - budget

    def get_least_s_norm(self, optimizer):
        """Fill and return the internal buffer with, per group, the sum of
        the ``ceil(s[i])`` smallest scores.

        The same tensor object is returned on every call and is overwritten
        in place.
        """
        eps = (optimizer.param_groups[0]['eps']
               if isinstance(optimizer, torch.optim.Adam) else None)
        buffer = self.__least_s_norm
        budgets = self.ceiled_s()
        for idx, group in enumerate(self.bncp_layers):
            scores = weight_list_to_scores(
                group, self.bncp_layers_dict, self.group_size, eps=eps)
            keep = int(budgets[idx].ceil().item())
            smallest = torch.topk(scores, keep, largest=False, sorted=False)[0]
            buffer[idx] = smallest.sum().item()
        return buffer

    def yloss(self, w_optimizer):
        """Return the dot product of ``y`` with ``get_least_s_norm``."""
        return self.y.dot(self.get_least_s_norm(w_optimizer))