class SubnetLinear(nn.Linear): # self.k is the % of weights remaining, a real number in [0,1] # self.popup_scores is a Parameter which has the same shape as self.weight # Gradients to self.weight, self.bias have been turned off. def __init__(self, in_features, out_features, bias=True): super(SubnetLinear, self).__init__(in_features, out_features, bias=True) self.popup_scores = Parameter(torch.Tensor(self.weight.shape)) nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5)) self.weight.requires_grad = False self.bias.requires_grad = False self.w = 0 # self.register_buffer('w', None) def set_prune_rate(self, k): self.k = k def forward(self, x): # Get the subnetwork by sorting the scores. adj = GetSubnet.apply(self.popup_scores.abs(), self.k) # Use only the subnetwork in the forward pass. self.w = self.weight * adj x = F.linear(x, self.w, self.bias) return x
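# SubnetLinear above calls GetSubnet.apply(...), which is not defined in this section. The sketch below is an
# assumption: it shows the commonly used edge-popup formulation, keeping the top-k fraction of scores with a
# straight-through gradient. The exact definition in the original codebase may differ.
import torch
import torch.autograd as autograd


class GetSubnet(autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Build a {0, 1} mask that keeps the k highest-scoring positions.
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())
        flat_out = out.flatten()  # view into `out`, so the writes below modify it
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1
        return out

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: gradients flow to the scores unchanged, none to k.
        return g, None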
class ActQuant_init(nn.Module): def __init__(self, act_bit=4, scale_coef=10.0, extern_init=False, init_model=nn.Sequential()): super(ActQuant_init, self).__init__() self.pwr_coef = 2**act_bit self.act_bit=act_bit self.scale_coef = Parameter(torch.ones(1)*scale_coef) if extern_init: param=list(init_model.parameters()) if param[0]>0.1 and param[0]<10.0: self.scale_coef=Parameter(param[0]) else: self.scale_coef=Parameter(torch.ones(1)*1.0) def forward(self, x): if self.act_bit==32: out=0.5*(x.abs() - (x-self.scale_coef.abs()).abs()+self.scale_coef.abs())/self.scale_coef.abs() else: out = 0.5*(x.abs() - (x-self.scale_coef.abs()).abs()+self.scale_coef.abs()) out = RoundFn.apply(out / self.scale_coef.abs(), self.pwr_coef) return out*2.0
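# ActQuant_init relies on RoundFn, which is also not shown here. A plausible sketch, assuming it is a
# straight-through rounding function that snaps its (already normalized) input onto `pwr_coef` uniform levels:
import torch
from torch.autograd import Function


class RoundFn(Function):
    @staticmethod
    def forward(ctx, x, pwr_coef):
        # Quantize x (expected in [0, 1]) onto a grid with pwr_coef steps.
        return torch.round(x * pwr_coef) / pwr_coef

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient w.r.t. x; no gradient w.r.t. the number of levels.
        return grad_output, None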
class FRN(nn.Module): def __init__(self, num_features, eps=1e-6, is_eps_leanable=False): """ weight = gamma, bias = beta beta, gamma: Variables of shape [1, 1, 1, C]. if TensorFlow Variables of shape [1, C, 1, 1]. if PyTorch eps: A scalar constant or learnable variable. """ super(FRN, self).__init__() self.num_features = num_features self.init_eps = eps self.is_eps_leanable = is_eps_leanable self.weight = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.Tensor(num_features)) if is_eps_leanable: self.eps = Parameter(torch.Tensor(1)) else: self.register_buffer('eps', torch.Tensor([eps])) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) if self.is_eps_leanable: nn.init.constant_(self.eps, self.init_eps) def extra_repr(self): return 'num_features={num_features}, eps={init_eps}'.format( **self.__dict__) def forward(self, x): """ 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow 0, 1, 2, 3 -> (B, C, H, W) in PyTorch TensorFlow code nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True) x = x * tf.rsqrt(nu2 + tf.abs(eps)) # This Code include TLU function max(y, tau) return tf.maximum(gamma * x + beta, tau) """ # Compute the mean norm of activations per channel. nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True) # Perform FRN. x = x * torch.rsqrt(nu2 + self.eps.abs()) # Scale and Bias x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view( 1, self.num_features, 1, 1) # x = self.weight * x + self.bias return x
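# FRN normalizes each channel by the mean of its squared activations over the spatial dimensions,
# nu2 = mean_{H,W}(x^2), followed by a learned per-channel scale and bias. A quick shape check,
# assuming the FRN class above is in scope:
import torch

frn = FRN(num_features=16)
x = torch.randn(2, 16, 8, 8)   # (B, C, H, W) in PyTorch layout
y = frn(x)
assert y.shape == x.shape      # the normalization is per-channel and shape-preserving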
class VBLinear(nn.Module): def __init__(self, in_features, out_features, prior_prec=10, map=True): super(VBLinear, self).__init__() self.n_in = in_features self.n_out = out_features self.prior_prec = prior_prec self.map = map self.bias = nn.Parameter(th.Tensor(out_features)) self.mu_w = Parameter(th.Tensor(out_features, in_features)) self.logsig2_w = nn.Parameter(th.Tensor(out_features, in_features)) self.reset_parameters() def reset_parameters(self): # TODO: Adapt to the newest pytorch initializations stdv = 1. / math.sqrt(self.mu_w.size(1)) self.mu_w.data.normal_(0, stdv) self.logsig2_w.data.zero_().normal_(-9, 0.001) # var init via Louizos self.bias.data.zero_() def KL(self, loguniform=False): if loguniform: k1 = 0.63576 k2 = 1.87320 k3 = 1.48695 log_alpha = self.logsig2_w - 2 * th.log(self.mu_w.abs() + 1e-8) kl = -th.sum(k1 * th.sigmoid(k2 + k3 * log_alpha) - 0.5 * F.softplus(-log_alpha) - k1) else: logsig2_w = self.logsig2_w.clamp(-11, 11) kl = 0.5 * (self.prior_prec * (self.mu_w.pow(2) + logsig2_w.exp()) - logsig2_w - 1 - np.log(self.prior_prec)).sum() return kl def forward(self, input): # Sampling free forward pass only if MAP prediction and no training rounds if self.map and not self.training: return F.linear(input, self.mu_w, self.bias) else: mu_out = F.linear(input, self.mu_w, self.bias) logsig2_w = self.logsig2_w.clamp(-11, 11) s2_w = logsig2_w.exp() var_out = F.linear(input.pow(2), s2_w) + 1e-8 return mu_out + var_out.sqrt() * th.randn_like(mu_out) def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.n_in) + ' -> ' \ + str(self.n_out) + ')'
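# VBLinear exposes its KL divergence per layer; in variational training these terms are typically summed into
# the objective next to the data loss. A minimal sketch under that assumption (the 1/60000 factor stands in for
# dividing by the dataset size and is not taken from the original code):
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Sequential(VBLinear(784, 100), nn.ReLU(), VBLinear(100, 10))
x, y = torch.randn(32, 784), torch.randint(0, 10, (32,))

logits = model(x)                    # sampled forward pass while model.training is True
nll = F.cross_entropy(logits, y)
kl = sum(m.KL() for m in model.modules() if isinstance(m, VBLinear))
loss = nll + kl / 60000
loss.backward()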
class SubnetConv(nn.Conv2d): # self.k is the % of weights remaining, a real number in [0,1] # self.popup_scores is a Parameter which has the same shape as self.weight # Gradients to self.weight, self.bias have been turned off by default. def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding=1, dilation=1, groups=1, bias=True, ): super(SubnetConv, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, ) self.popup_scores = Parameter(torch.Tensor(self.weight.shape)) nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5)) self.weight.requires_grad = False if self.bias is not None: self.bias.requires_grad = False self.w = 0 def set_prune_rate(self, k): self.k = k def forward(self, x): # Get the subnetwork by sorting the scores. adj = GetSubnet.apply(self.popup_scores.abs(), self.k) # Use only the subnetwork in the forward pass. self.w = self.weight * adj x = F.conv2d(x, self.w, self.bias, self.stride, self.padding, self.dilation, self.groups) return x
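# Both Subnet layers expose set_prune_rate(k). A small helper, assuming the SubnetConv and SubnetLinear classes
# above are in scope, that applies one global prune rate to every such layer before training the popup scores:
import torch.nn as nn


def set_prune_rate_model(model: nn.Module, k: float) -> None:
    """Set the fraction of weights kept (k in [0, 1]) on every Subnet layer of `model`."""
    for module in model.modules():
        if isinstance(module, (SubnetConv, SubnetLinear)):
            module.set_prune_rate(k)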
class RTML(nn.Module):
    def __init__(self, L=3, lamb=5):
        super(RTML, self).__init__()
        self.L = L
        self.N = len(att_names)
        self.lamb = lamb
        self.theta = Parameter(torch.Tensor(self.L, 300, 300))
        # The extra (L+1)-th column of alpha lets the constrained norm fall below lamb.
        self.alpha = Parameter(torch.Tensor(self.N, self.L + 1))
        self.reset_parameters()
        self.att_emb = nn.Embedding(self.N, 300)
        if PREINIT:
            self.att_emb.weight.data = _load_vectors(att_names).cuda()
        else:
            _np_emb = np.random.randn(self.N, 300)
            _np_emb = _np_emb / np.square(_np_emb).sum(1)[:, None]
            self.att_emb.weight.data = torch.FloatTensor(_np_emb).cuda()

    def reset_parameters(self):
        for weight in [self.theta, self.alpha]:
            stdv = 1. / math.sqrt(weight.size(1))
            weight.data.uniform_(-stdv, stdv)

    def forward(self, word_embs):
        # keepdim=True so the per-row L1 norm broadcasts against alpha's (N, L+1) shape.
        alpha_norm = self.alpha.abs().sum(1, keepdim=True)
        alpha_constrained = self.lamb * self.alpha / alpha_norm.expand_as(self.alpha)
        R_flat = alpha_constrained[:, :-1] @ self.theta.view(self.L, -1)
        R = R_flat.view(self.N, 300, 300)
        s = 0
        preds = []
        for i, att_size in enumerate(dom_sizes):
            e = s + att_size
            att_embs = self.att_emb.weight[s:e].t()
            s = e
            p1 = word_embs @ R[i]
            p2 = p1 @ att_embs
            preds.append(p2)
        preds = torch.cat(preds, 1)
        return preds
class Conv2dDPQ(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 qmin=1e-3, qmax=100, dmin=1e-5, dmax=10, bias=True, sign=True, wbits=4, abits=4,
                 mode=Qmodes.layer_wise):
        """
        :param d_init: the initial quantization step size (alpha)
        :param mode: Qmodes.layer_wise or Qmodes.kernel_wise
        :param xmax_init: the quantization range for the whole weight tensor
        """
        super(Conv2dDPQ, self).__init__(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding,
                                        dilation=dilation, groups=groups, bias=bias)
        self.qmin = qmin
        self.qmax = qmax
        self.dmin = dmin
        self.dmax = dmax
        self.q_mode = mode
        self.sign = sign
        self.nbits = wbits
        self.act_dpq = ActDPQ(signed=False, nbits=abits)
        self.alpha = Parameter(torch.Tensor(1))
        self.xmax = Parameter(torch.Tensor(1))
        self.weight.requires_grad_(True)
        if bias:
            self.bias.requires_grad_(True)
        self.register_buffer('init_state', torch.zeros(1))

    def get_nbits(self):
        abits = self.act_dpq.get_nbits()
        xmax = self.xmax.abs().item()
        alpha = self.alpha.abs().item()
        if self.sign:
            nbits = math.ceil(math.log(xmax / alpha + 1) / math.log(2) + 1)
        else:
            nbits = math.ceil(math.log(xmax / alpha + 1) / math.log(2))
        self.nbits = nbits
        return abits, nbits

    def get_quan_filters(self, filters):
        if self.training and self.init_state == 0:
            Qp = 2 ** (self.nbits - 1) - 1
            self.xmax.data.copy_(filters.abs().max())
            self.alpha.data.copy_(self.xmax / Qp)
            # self.alpha[self.index].data.copy_(2 * filters.abs().mean() / math.sqrt(Qp))
            # self.xmax[self.index].data.copy_(self.alpha[self.index] * Qp)
            self.init_state.fill_(1)
        Qp = (self.xmax.detach() / self.alpha.detach()).abs().item()
        g = 1.0 / math.sqrt(filters.numel() * Qp)
        alpha = grad_scale(self.alpha, g)
        xmax = grad_scale(self.xmax, g)
        w = F.hardtanh(filters / xmax.abs(), -1, 1) * xmax.abs()
        w = w / alpha.abs()
        wq = round_pass(w) * alpha.abs()
        return wq

    def forward(self, x):
        if self.act_dpq is not None:
            x = self.act_dpq(x)
        wq = self.get_quan_filters(self.weight)
        return F.conv2d(x, wq, self.bias, self.stride, self.padding, self.dilation, self.groups)
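# Conv2dDPQ uses grad_scale and round_pass, neither of which appears in this section. The sketch below gives
# the common LSQ-style definitions and is an assumption about the missing helpers: forward values are unchanged
# (or rounded) while gradients are rescaled (or passed straight through).
import torch


def grad_scale(x, scale):
    # Forward: x unchanged. Backward: gradient multiplied by `scale`.
    y_grad = x * scale
    return (x - y_grad).detach() + y_grad


def round_pass(x):
    # Forward: round(x). Backward: identity gradient (straight-through estimator).
    return (x.round() - x).detach() + x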
class SparseTensor(nn.Module): def __init__(self, tensor_size, initial_sparsity, sub_kernel_granularity=4): super(SparseTensor, self).__init__() self.s_tensor = Parameter(torch.Tensor(torch.Size(tensor_size))) self.initial_sparsity = initial_sparsity self.sub_kernel_granularity = sub_kernel_granularity assert self.s_tensor.dim() == 2 or self.s_tensor.dim( ) == 4, "can only do 2D or 4D sparse tensors" trailing_dimensions = [1] * (4 - sub_kernel_granularity) self.register_buffer( 'mask', torch.Tensor(*(tensor_size[:sub_kernel_granularity]))) self.normalize_coeff = np.prod( tensor_size[sub_kernel_granularity:]).item() self.conv_tensor = False if self.s_tensor.dim() == 2 else True self.mask.zero_() flat_mask = self.mask.view(-1) indices = np.arange(flat_mask.size(0)) np.random.shuffle(indices) flat_mask[indices[:int((1 - initial_sparsity) * flat_mask.size(0) + 0.1)]] = 1 self.grown_indices = None self.init_parameters() self.reinitialize_unused() self.tensor_sign = torch.sign(self.s_tensor.data.view(-1)) def reinitialize_unused(self, reinitialize_unused_to_zero=True): unused_positions = (self.mask < 0.5) if reinitialize_unused_to_zero: self.s_tensor.data[unused_positions] = torch.zeros( self.s_tensor.data[unused_positions].size()).to( self.s_tensor.device) else: if self.conv_tensor: n = self.s_tensor.size(0) * self.s_tensor.size( 2) * self.s_tensor.size(3) self.s_tensor.data[unused_positions] = torch.zeros( self.s_tensor.data[unused_positions].size()).normal_( 0, math.sqrt(2. / n)).to(self.s_tensor.device) else: stdv = 1. / math.sqrt(self.s_tensor.size(1)) self.s_tensor.data[unused_positions] = torch.zeros( self.s_tensor.data[unused_positions].size()).normal_( 0, stdv).to(self.s_tensor.device) def init_parameters(self): stdv = 1 / math.sqrt(np.prod(self.s_tensor.size()[1:])) self.s_tensor.data.uniform_(-stdv, stdv) def prune_sign_change(self, reinitialize_unused_to_zero=True, enable_print=False): W_flat = self.s_tensor.data.view(-1) new_tensor_sign = torch.sign(W_flat) mask_flat = self.mask.view(-1) mask_indices = torch.nonzero(mask_flat > 0.5).view(-1) sign_change_indices = mask_indices[( (new_tensor_sign[mask_indices] * self.tensor_sign[mask_indices].to(new_tensor_sign.device)) < -0.5).nonzero().view(-1)] mask_flat[sign_change_indices] = 0 self.reinitialize_unused(reinitialize_unused_to_zero) cutoff = sign_change_indices.numel() if enable_print: print('pruned {} connections'.format(cutoff)) if self.grown_indices is not None and enable_print: overlap = np.intersect1d(sign_change_indices.cpu().numpy(), self.grown_indices.cpu().numpy()) print('pruned {} ({} %) just grown weights'.format( overlap.size, overlap.size * 100.0 / self.grown_indices.size(0) if self.grown_indices.size(0) > 0 else 0.0)) self.tensor_sign = new_tensor_sign return sign_change_indices def prune_small_connections(self, prune_fraction, reinitialize_unused_to_zero=True): if self.conv_tensor and self.sub_kernel_granularity < 4: W_flat = self.s_tensor.abs().sum( list(np.arange(self.sub_kernel_granularity, 4))).view(-1) / self.normalize_coeff else: W_flat = self.s_tensor.data.view(-1) mask_flat = self.mask.view(-1) mask_indices = torch.nonzero(mask_flat > 0.5).view(-1) W_masked = W_flat[mask_indices] sorted_W_indices = torch.sort(torch.abs(W_masked))[1] cutoff = int(prune_fraction * W_masked.numel()) + 1 mask_flat[mask_indices[sorted_W_indices[:cutoff]]] = 0 self.reinitialize_unused(reinitialize_unused_to_zero) # print('pruned {} connections'.format(cutoff)) # if self.grown_indices is not None: # overlap = 
np.intersect1d(mask_indices[sorted_W_indices[:cutoff]].cpu().numpy(),self.grown_indices.cpu().numpy()) #print('pruned {} ({} %) just grown weights'.format(overlap.size,overlap.size * 100.0 / self.grown_indices.size(0))) return mask_indices[sorted_W_indices[:cutoff]] def prune_threshold(self, threshold, reinitialize_unused_to_zero=True): if self.conv_tensor and self.sub_kernel_granularity < 4: W_flat = self.s_tensor.abs().sum( list(np.arange(self.sub_kernel_granularity, 4))).view(-1) / self.normalize_coeff else: W_flat = self.s_tensor.data.view(-1) mask_flat = self.mask.view(-1) mask_indices = torch.nonzero(mask_flat > 0.5).view(-1) W_masked = W_flat[mask_indices] prune_indices = (W_masked.abs() < threshold).nonzero().view(-1) if mask_indices.size(0) == prune_indices.size(0): print('removing all. keeping one') prune_indices = prune_indices[1:] mask_flat[mask_indices[prune_indices]] = 0 # if mask_indices.numel() > 0 : # print('pruned {}/{}({:.2f}) connections'.format(prune_indices.numel(),mask_indices.numel(),prune_indices.numel()/mask_indices.numel())) # if self.grown_indices is not None and self.grown_indices.size(0) != 0 : # overlap = np.intersect1d(mask_indices[prune_indices].cpu().numpy(),self.grown_indices.cpu().numpy()) # print('pruned {} ({} %) just grown weights'.format(overlap.size,overlap.size * 100.0 / self.grown_indices.size(0))) self.reinitialize_unused(reinitialize_unused_to_zero) return mask_indices[prune_indices] def grow_random(self, grow_fraction, pruned_indices=None, enable_print=False, n_to_add=None): mask_flat = self.mask.view(-1) mask_zero_indices = torch.nonzero(mask_flat < 0.5).view(-1) if pruned_indices is not None: cutoff = pruned_indices.size(0) mask_zero_indices = torch.Tensor( np.setdiff1d(mask_zero_indices.cpu().numpy(), pruned_indices.cpu().numpy())).long().to( mask_zero_indices.device) else: cutoff = int(grow_fraction * mask_zero_indices.size(0)) if n_to_add is not None: cutoff = n_to_add if mask_zero_indices.numel() < cutoff: print('******no place to grow {} connections, growing {} instead'. format(cutoff, mask_zero_indices.numel())) cutoff = mask_zero_indices.numel() if enable_print: print('grown {} connections'.format(cutoff)) self.grown_indices = mask_zero_indices[torch.randperm( mask_zero_indices.numel())][:cutoff] mask_flat[self.grown_indices] = 1 return cutoff def get_sparsity(self): active_elements = self.mask.sum() * np.prod( self.s_tensor.size()[self.sub_kernel_granularity:]).item() return (active_elements, 1 - active_elements / self.s_tensor.numel()) def forward(self): if self.conv_tensor: return self.mask.view( *(self.mask.size() + (1, ) * (4 - self.sub_kernel_granularity))) * self.s_tensor else: return self.mask * self.s_tensor def extra_repr(self): return 'full tensor size : {} , sparsity mask : {} , sub kernel granularity : {}'.format( self.s_tensor.size(), self.get_sparsity(), self.sub_kernel_granularity)
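# A sketch of the prune/regrow cycle that SparseTensor supports, assuming the class above together with its
# torch and numpy imports; the concrete sizes and fractions are illustrative only:
import torch

st = SparseTensor([64, 32], initial_sparsity=0.9)            # 2D tensor, ~10% of entries active
dense = st()                                                  # forward() returns mask * weights

pruned = st.prune_small_connections(prune_fraction=0.2)       # drop the smallest surviving weights
st.grow_random(grow_fraction=0.0, pruned_indices=pruned)      # regrow the same number at random positions
active_elements, sparsity = st.get_sparsity()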
class group_relaxed_L0Dense(Module): """Implementation of TFL regularization for the input units of a fully connected layer""" def __init__(self, in_features, out_features, bias=True, lamba=1., beta=4., weight_decay=1., **kwargs): """ :param in_features: input dimensionality :param out_features: output dimensionality :param bias: whether we use bias :param lamba: strength of the TF1 regularization """ super(group_relaxed_L0Dense, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(in_features, out_features)) self.u = torch.rand(in_features, out_features) self.u = self.u.to('cuda') if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.lamba = lamba self.beta = beta self.weight_decay = weight_decay self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode='fan_out') if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, **kwargs): #self.weight.data = F.normalize(self.weight.data, p=2, dim=1) m = Hardshrink((2 * self.lamba / self.beta)**(1 / 2)) self.u.data = m(self.weight.data) def grow_beta(self, growth_factor): self.beta = self.beta * growth_factor def _reg_w(self, **kwargs): logpw = -self.beta * torch.sum( 0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt( self.out_features) * torch.sum( torch.pow(torch.sum(self.weight.pow(2), 1), 0.5)) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_zero_u(self): total = np.prod(self.u.size()) zero = total - self.u.nonzero().size(0) return zero def count_zero_w(self): return torch.sum((self.weight.abs() < 1e-5).int()).item() def count_weight(self): return np.prod(self.u.size()) def count_active_neuron(self): return torch.sum( torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item() def count_total_neuron(self): return self.in_features def count_expected_flops_and_l0(self): ppos = torch.sum(self.weight.abs() > 0.000001).item() expected_flops = (2 * ppos - 1) * self.out_features expected_l0 = ppos * self.out_features if self.bias is not None: expected_flops += self.out_features expected_l0 += self.out_features return expected_flops, expected_l0 def forward(self, input): output = input.mm(self.weight) if self.bias is not None: output.add_(self.bias.view(1, self.out_features).expand_as(output)) return output def __repr__(self): return self.__class__.__name__+' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', lambda: ' \ + str(self.lamba) + ')'
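# group_relaxed_L0Dense alternates gradient steps on the weights with a closed-form hard-shrinkage update of the
# auxiliary variable u, while beta is grown over time to tighten the relaxation. A rough training sketch under
# those assumptions (the sign and scaling of the regularizer, the growth factor, and the schedule are all
# illustrative; the layer itself requires a CUDA device because it places u on the GPU):
import torch
import torch.nn.functional as F

layer = group_relaxed_L0Dense(784, 100).cuda()
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(32, 784, device='cuda')
y = torch.randint(0, 100, (32,), device='cuda')

for step in range(100):
    optimizer.zero_grad()
    loss = F.cross_entropy(layer(x), y) - layer.regularization()   # regularization() is a log-prior term
    loss.backward()
    optimizer.step()
    layer.constrain_parameters()        # hard-shrinkage update of u
    if step % 10 == 0:
        layer.grow_beta(1.1)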
class MAPConv2d(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, weight_decay=1., **kwargs): super(MAPConv2d, self).__init__() self.weight_decay = weight_decay self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.weight = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() self.input_shape = None print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode='fan_in') if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, thres_std=1.): pass def _reg_w(self, **kwargs): logpw = -torch.sum(self.weight_decay * .5 * (self.weight.pow(2))) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_zero_w(self): return torch.sum((self.weight.abs() < 1e-5).int()).item() def count_zero_u(self): return 0 def count_weight(self): return np.prod(self.weight.size()) def count_active_neuron(self): return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) / (self.in_channels * self.kernel_size[0] * self.kernel_size[1])) > 1e-5).item() def count_total_neuron(self): return self.out_channels def count_expected_flops_and_l0(self): ppos = self.out_channels n = self.kernel_size[0] * self.kernel_size[ 1] * self.in_channels # vector_length flops_per_instance = n + (n - 1 ) # (n: multiplications and n-1: additions) num_instances_per_filter = ( (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 # for rows num_instances_per_filter *= ( (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 # multiplying with cols flops_per_filter = num_instances_per_filter * flops_per_instance expected_flops = flops_per_filter * ppos # multiply with number of filters expected_l0 = n * ppos if self.bias is not None: # since the gate is applied to the output we also reduce the bias computation expected_flops += num_instances_per_filter * ppos expected_l0 += ppos return expected_flops, expected_l0 def forward(self, input_): if self.input_shape is None: self.input_shape = input_.size() output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return output def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} ' ', stride={stride}, weight_decay={weight_decay}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if self.bias is None: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
class MAPDense(Module): def __init__(self, in_features, out_features, bias=True, weight_decay=1., **kwargs): super(MAPDense, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(in_features, out_features)) self.weight_decay = weight_decay if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode='fan_out') if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, **kwargs): pass def _reg_w(self, **kwargs): logpw = -torch.sum(self.weight_decay * .5 * (self.weight.pow(2))) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_zero_w(self): return torch.sum((self.weight.abs() < 1e-5).int()).item() def count_zero_u(self): return 0 def count_weight(self): return np.prod(self.weight.size()) def count_active_neuron(self): return torch.sum( torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item() def count_total_neuron(self): return self.in_features def count_expected_flops_and_l0(self): # dim_in multiplications and dim_in - 1 additions for each output neuron for the weights # + the bias addition for each neuron # total_flops = (2 * in_features - 1) * out_features + out_features expected_flops = (2 * self.in_features - 1) * self.out_features expected_l0 = self.in_features * self.out_features if self.bias is not None: expected_flops += self.out_features expected_l0 += self.out_features return expected_flops, expected_l0 def forward(self, input): output = input.mm(self.weight) if self.bias is not None: output.add_(self.bias.view(1, self.out_features).expand_as(output)) return output def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', weight_decay: ' \ + str(self.weight_decay) + ')'
class SubnetConv(nn.Conv2d): # self.k is the % of weights remaining, a real number in [0,1] # self.popup_scores is a Parameter which has the same shape as self.weight # Gradients to self.weight, self.bias have been turned off by default. def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, ): super(SubnetConv, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, ) # Weight pruning # self.popup_scores = Parameter(torch.Tensor(self.weight.shape)) # Channel Finetuning or Resume Pruning # self.popup_scores = Parameter(torch.Tensor(torch.Size([1,self.weight.shape[1],1,1]))) # Channel Pruning self.popup_scores = Parameter( torch.Tensor(torch.Size([self.weight.shape[0], 1, 1, 1]))) nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5)) self.weight.requires_grad = False if self.bias is not None: self.bias.requires_grad = False self.w = 0 def set_prune_rate(self, k): self.k = k def forward(self, x): """ Unstructured comparison remaining_weights = int(self.k * len(self.weight.flatten())) idx_same_top_weights_scores = list( set(torch.topk(self.weight.abs().flatten(), remaining_weights).indices.tolist()).intersection( set(torch.topk(self.popup_scores.abs().flatten(), remaining_weights).indices.tolist()))) num_remaining_weights = len(idx_same_top_weights_scores) print( f"SubnetConv: Number of same indices for scores and weights that are left after pruning: " f"{num_remaining_weights}. These are {float(num_remaining_weights / remaining_weights)} percent of the " f"weights kept.") """ """ Structured Comparison remaining_filters = int(self.k * self.weight.shape[0]) idx_same_top_weights_scores = list(set( torch.topk(torch.linalg.norm(self.weight.abs().reshape(self.weight.shape[0], -1), 1, dim=1), remaining_filters).indices.tolist()).intersection( torch.topk(torch.linalg.norm(self.popup_scores.abs().reshape(self.popup_scores.shape[0], -1), 1, dim=1), remaining_filters).indices.tolist())) num_remaining_filters = len(idx_same_top_weights_scores) print( f"SubnetConv: Number of same indices for filters that are left after pruning using scores or weights : " f"{num_remaining_filters}. These are {float(num_remaining_filters / remaining_filters)} percent of the " f"filters kept.") """ """ Channel Prune VGG16 global conv_nr if conv_nr == 13: conv_nr = 1 else: conv_nr += 1 # Get the subnetwork by sorting the scores. mask_conv_50 = [1.0, 1.0, 0.984375, 1.0, 1.0, 0.98828125, 0.98046875, 1.0, 0.96875, 0.359375, 0.099609375, 0.1015625, 0.099609375] mask_conv_10 = [1.0, 0.5, 0.46875, 0.4921875, 0.484375, 0.4765625, 0.5, 0.5, 0.48242188, 0.05078125, 0.0234375, 0.015625, 0.015625] k = mask_conv_10[conv_nr-1] if conv_nr == 1: adj = GetSubnet.apply(self.popup_scores.abs(), 1) else: adj = GetSubnet.apply(self.popup_scores.abs(), self.k) """ global conv_nr if conv_nr == 28: conv_nr = 1 else: conv_nr += 1 """ mask_wrn_50 = [1, 0.5, 0.171875, 0.5, 0.5625, 0.5, 0.359375, 0.40625, 0.375, 0.1875, 0.390625, 0.4453125, 0.390625, 0.328125, 0.171875, 0.4765625, 0.3046875, 0.140625, 0.2265625,0.640625, 0.49609375, 0.640625, 0.6015625, 0.6796875, 0.46875, 0.52734375, 0.50390625, 0.48825125] # Add mask for 0.1 Channel pruning here mask_wrn_10 = None k = mask_wrn_50[conv_nr-1] adj = GetSubnet.apply(self.popup_scores.abs(), k) """ if conv_nr == 1: adj = GetSubnet.apply(self.popup_scores.abs(), 1) else: adj = GetSubnet.apply(self.popup_scores.abs(), self.k) # Use only the subnetwork in the forward pass. 
self.w = self.weight * adj x = F.conv2d(x, self.w, self.bias, self.stride, self.padding, self.dilation, self.groups) return x
class DenseLinear(nn.Module): __constants__ = ['in_features', 'out_features'] def __init__(self, in_features, out_features, use_bias=True, use_mask=True, **kwargs): super(DenseLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(out_features, in_features)) if use_bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.reset_parameters(**kwargs) # self._initial_weight = self.weight.data.clone() # self._initial_bias = self.bias.data.clone() if use_bias else None self.use_mask = use_mask self.mask = torch.ones_like(self.weight, dtype=torch.bool) def reset_parameters(self, **kwargs): if len(kwargs.keys()) == 0: # default init, see https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear init.kaiming_uniform_(self.weight, a=math.sqrt(5)) else: init.kaiming_uniform_(self.weight, **kwargs) if self.bias is not None: # default init, see https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def forward(self, inp: torch.Tensor): masked_weight = self.weight * self.mask if self.use_mask else self.weight return nn.functional.linear(inp, masked_weight, self.bias) def extra_repr(self): return 'in_features={}, out_features={}, bias={}'.format( self.in_features, self.out_features, self.bias is not None) def prune_by_threshold(self, thr): self.mask *= (self.weight.abs() >= thr) def prune_by_rank(self, rank): if rank == 0: return weight_val = self.weight[self.mask == 1.] sorted_abs_weight = weight_val.abs().sort()[0] thr = sorted_abs_weight[rank] self.prune_by_threshold(thr) def prune_by_pct(self, pct): prune_idx = int(self.num_weight * pct) self.prune_by_rank(prune_idx) def retain_by_threshold(self, thr): self.mask *= (self.weight.abs() >= thr) def retain_by_rank(self, rank): weights_val = self.weight[self.mask == 1.] sorted_abs_weights = weights_val.abs().sort(descending=True)[0] thr = sorted_abs_weights[rank] self.retain_by_threshold(thr) def random_prune_by_pct(self, pct): prune_idx = int(self.num_weight * pct) rand = torch.rand(size=self.mask.size(), device=self.mask.device) rand_val = rand[self.mask == 1] sorted_abs_rand = rand_val.sort()[0] thr = sorted_abs_rand[prune_idx] self.mask *= (rand >= thr) # def reinitialize(self): # self.weight = Parameter(self._initial_weight) # if self._initial_bias is not None: # self.bias = Parameter(self._initial_bias) def to_sparse(self, transpose=False) -> SparseLinear: """ by chance, some entries with mask = 1 can have a 0 value. Thus, the to_sparse methods give a different size there's no efficient way to solve it yet """ sparse_bias = None if self.bias is None else self.bias.reshape((-1, 1)) sparse_linear = SparseLinear((self.weight * self.mask).to_sparse(), sparse_bias, self.mask) if transpose: sparse_linear.transpose = True return sparse_linear def move_data(self, device: torch.device): self.mask = self.mask.to(device) def to(self, *args, **kwargs): device = torch._C._nn._parse_to(*args, **kwargs)[0] if device is not None: self.move_data(device) return super(DenseLinear, self).to(*args, **kwargs) @property def num_weight(self) -> int: return self.mask.sum().item()
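# DenseLinear keeps an explicit boolean mask next to the weight; magnitude pruning zeroes mask entries and the
# forward pass multiplies the weight by the mask. A short usage sketch, assuming the class above is importable:
import torch

fc = DenseLinear(256, 64)
fc.prune_by_pct(0.5)                         # mask the 50% smallest-magnitude weights
print(fc.mask.float().mean().item())         # fraction of weights still active, roughly 0.5
out = fc(torch.randn(8, 256))                # forward pass uses weight * mask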
class FilterStripe(nn.Conv2d): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1): super(FilterStripe, self).__init__(in_channels, out_channels, kernel_size, stride, kernel_size // 2, groups=1, bias=False) self.BrokenTarget = None self.FilterSkeleton = Parameter(torch.ones(self.out_channels, self.kernel_size[0], self.kernel_size[1]), requires_grad=True) def forward(self, x): if self.BrokenTarget is not None: out = torch.zeros(x.shape[0], self.FilterSkeleton.shape[0], int(np.ceil(x.shape[2] / self.stride[0])), int(np.ceil(x.shape[3] / self.stride[1]))) if x.is_cuda: out = out.cuda() x = F.conv2d(x, self.weight) l, h = 0, 0 for i in range(self.BrokenTarget.shape[0]): for j in range(self.BrokenTarget.shape[1]): h += self.FilterSkeleton[:, i, j].sum().item() out[:, self.FilterSkeleton[:, i, j]] += self.shift( x[:, l:h], i, j)[:, :, ::self.stride[0], ::self.stride[1]] l += self.FilterSkeleton[:, i, j].sum().item() return out else: return F.conv2d(x, self.weight * self.FilterSkeleton.unsqueeze(1), stride=self.stride, padding=self.padding, groups=self.groups) def prune_in(self, in_mask=None): self.weight = Parameter(self.weight[:, in_mask]) self.in_channels = in_mask.sum().item() def prune_out(self, threshold): out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0 if out_mask.sum() == 0: out_mask[0] = True self.weight = Parameter(self.weight[out_mask]) self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask], requires_grad=True) self.out_channels = out_mask.sum().item() return out_mask def _break(self, threshold): self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1)) self.FilterSkeleton = Parameter( (self.FilterSkeleton.abs() > threshold), requires_grad=False) if self.FilterSkeleton.sum() == 0: self.FilterSkeleton.data[0][0][0] = True self.out_channels = self.FilterSkeleton.sum().item() self.BrokenTarget = self.FilterSkeleton.sum(dim=0) self.kernel_size = (1, 1) self.weight = Parameter( self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1)[self.FilterSkeleton.permute( 1, 2, 0).reshape(-1)]) def update_skeleton(self, sr, threshold): self.FilterSkeleton.grad.data.add_( sr * torch.sign(self.FilterSkeleton.data)) mask = self.FilterSkeleton.data.abs() > threshold self.FilterSkeleton.data.mul_(mask) self.FilterSkeleton.grad.data.mul_(mask) out_mask = mask.sum(dim=(1, 2)) != 0 return out_mask def shift(self, x, i, j): return F.pad(x, (self.BrokenTarget.shape[0] // 2 - j, j - self.BrokenTarget.shape[0] // 2, self.BrokenTarget.shape[0] // 2 - i, i - self.BrokenTarget.shape[1] // 2), 'constant', 0) def extra_repr(self): s = ( '{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}') return s.format(**self.__dict__)
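# Stripe-wise pruning with FilterStripe has two phases: during training, update_skeleton() adds the L1
# subgradient to the FilterSkeleton and zeroes entries under a threshold; afterwards prune_out() drops filters
# with no surviving stripes and _break() folds the skeleton into 1x1 kernels served by the shifted forward path.
# A minimal sketch of that flow, assuming the class above (call order and thresholds are assumptions):
import torch

conv = FilterStripe(16, 32, kernel_size=3, stride=1)
x = torch.randn(4, 16, 8, 8)

conv(x).sum().backward()                        # training step: populate gradients
conv.update_skeleton(sr=1e-4, threshold=0.01)   # push the skeleton toward stripe sparsity

conv.prune_out(threshold=0.01)                  # keep filters with at least one surviving stripe
conv._break(threshold=0.01)                     # convert surviving stripes to 1x1 kernels
y = conv(x)                                     # now routed through the BrokenTarget / shift path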
class CLTLinear(nn.Module): def __init__(self, in_features, out_features, prior_prec=10, relu_act=True, elu_act=False): super(CLTLinear, self).__init__() self.n_in = in_features self.n_out = out_features self.prior_prec = prior_prec assert not ( relu_act and elu_act ) # A single layer can only do either relu or elu activation self.relu_act = relu_act self.elu_act = elu_act self.bias = nn.Parameter(th.Tensor(out_features)) self.mu_w = Parameter(th.Tensor(out_features, in_features)) self.logsig2_w = nn.Parameter(th.Tensor(out_features, in_features)) self.reset_parameters() def reset_parameters(self): # TODO: Adapt to the newest pytorch initializations stdv = 1. / math.sqrt(self.mu_w.size(1)) self.mu_w.data.normal_(0, stdv) self.logsig2_w.data.zero_().normal_(-9, 0.001) self.bias.data.zero_() def KL(self, loguniform=False): if loguniform: k1 = 0.63576 k2 = 1.87320 k3 = 1.48695 log_alpha = self.logsig2_w - 2 * th.log(self.mu_w.abs() + 1e-8) kl = -th.sum(k1 * F.sigmoid(k2 + k3 * log_alpha) - 0.5 * F.softplus(-log_alpha) - k1) else: logsig2_w = self.logsig2_w.clamp(-11, 11) kl = 0.5 * (self.prior_prec * (self.mu_w.pow(2) + logsig2_w.exp()) - logsig2_w - 1 - np.log(self.prior_prec)).sum() return kl def cdf(self, x, mu=0., sig=1.): return 0.5 * (1 + th.erf((x - mu) / (sig * math.sqrt(2)))) def pdf(self, x, mu=0., sig=1.): return (1 / (math.sqrt(2 * math.pi) * sig)) * th.exp(-0.5 * ( (x - mu) / sig).pow(2)) def relu_moments(self, mu, sig): alpha = mu / sig cdf = self.cdf(alpha) pdf = self.pdf(alpha) relu_mean = mu * cdf + sig * pdf relu_var = (sig.pow(2) + mu.pow(2)) * cdf + mu * sig * pdf - relu_mean.pow(2) relu_var.clamp_(1e-8) # Avoid negative variance due to numerics return relu_mean, relu_var def elu_moments_orig(self, mu, sig): # the original method without simplifications sig2 = sig.pow(2) elu_mean = th.exp(mu.clamp_max(10) + sig2 / 2) * self.cdf( -(mu + sig2) / sig) - self.cdf(-mu / sig) elu_mean += mu * self.cdf(mu / sig) + sig * self.pdf(mu / sig) elu_var = th.exp(2 * mu.clamp_max(10) + 2 * sig2) * self.cdf( -(mu + 2 * sig2) / sig) elu_var += -2 * th.exp(mu.clamp_max(10) + sig2 / 2) * self.cdf( -(mu + sig2) / sig) elu_var += self.cdf(-mu / sig) elu_var += (sig2 + mu.pow(2)) * self.cdf( mu / sig) + mu * sig * self.pdf(mu / sig) elu_var += -elu_mean.pow(2) elu_var.clamp_min_(1e-8) # Avoid negative variance due to numerics return elu_mean, elu_var def elu_moments(self, mu, sig): # NOTE: For now it is without alpha or the selu extension! 
# Note: Takes roughly 3x as much time as the relu # Clamp the mus to avoid problems in the expectation sig2 = sig.pow(2) alpha = mu / sig cdf_alpha = self.cdf(alpha) pdf_alpha = self.pdf(alpha) cdf_malpha = 1 - cdf_alpha cdf_malphamsig = self.cdf(-alpha - sig) elu_mean = th.exp(mu.clamp_max(10) + sig2 / 2) * cdf_malphamsig - cdf_malpha elu_mean += mu * cdf_alpha + sig * pdf_alpha elu_var = th.exp(2 * mu.clamp_max(10) + 2 * sig2) * self.cdf(-alpha - 2 * sig) elu_var += -2 * th.exp(mu.clamp_max(10) + sig2 / 2) * cdf_malphamsig elu_var += cdf_malpha elu_var += (sig2 + mu.pow(2)) * cdf_alpha + mu * sig * pdf_alpha elu_var += -elu_mean.pow(2) elu_var.clamp_min_(1e-8) # Avoid negative variance due to numerics return elu_mean, elu_var def forward(self, mu_inp, var_inp=None): s2_w = self.logsig2_w.exp() mu_out = F.linear(mu_inp, self.mu_w, self.bias) if var_inp is None: var_out = F.linear(mu_inp.pow(2), s2_w) + 1e-8 else: var_out = F.linear(var_inp + mu_inp.pow(2), s2_w) + F.linear( var_inp, self.mu_w.pow(2)) + 1e-8 if self.relu_act: mu_out, var_out = self.relu_moments(mu_out, var_out.sqrt()) if self.elu_act: mu_out, var_out = self.elu_moments(mu_out, var_out.sqrt()) return mu_out, var_out # + 1e-8 Already provided in the moment computation def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.n_in) + ' -> ' \ + str(self.n_out) \ + f", activation={self.relu_act or self.elu_act}" \ + f" ({'relu' if self.relu_act else ('elu' if self.elu_act else '')}))"
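# CLTLinear propagates a mean and a variance through every layer (moment matching through the linear map and the
# ReLU/ELU nonlinearity). Layers are chained by feeding the output moments of one into the next; only the first
# layer receives a plain input with var_inp=None. A sketch assuming the class above:
import torch as th

layer1 = CLTLinear(20, 50, relu_act=True, elu_act=False)
layer2 = CLTLinear(50, 10, relu_act=False, elu_act=False)   # no activation moments on the output layer

x = th.randn(8, 20)
mu, var = layer1(x)          # var_inp defaults to None for the first layer
mu, var = layer2(mu, var)    # later layers consume both moments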
class Linear(nn.Module): __constants__ = ['in_features', 'out_features'] in_features: int out_features: int weight: Tensor def __init__(self, in_features: int, out_features: int, bias: bool = True, activation="ReLU", hidden_dim=None, hidden_activation="ReLU") -> None: super(Linear, self).__init__() self.in_features = in_features self.out_features = out_features self.hidden_dim = hidden_dim self.hidden_activation = hidden_activation if hidden_dim is None: self.dims = vector(in_features, out_features) self.weight = Parameter(torch.zeros(out_features, in_features)) if bias: self.bias = Parameter(torch.zeros(out_features)) else: self.register_parameter('bias', None) self.activation = get_activation_layer(activation) else: self.dims = vector(in_features, *vector(hidden_dim), out_features) self.weight = nn.ParameterList(self.dims.map_k(lambda in_dim, out_dim: Parameter(torch.zeros(out_dim, in_dim)), 2)) if bias: self.bias = nn.ParameterList(self.dims.map_k(lambda in_dim, out_dim: Parameter(torch.zeros(out_dim)), 2)) else: self.register_parameter('bias', None) self.activation = vector(get_activation_layer(hidden_activation) for _ in range(len(hidden_dim))) self.activation.append(get_activation_layer(activation)) self.reset_parameters() def reset_parameters(self) -> None: if self.hidden_dim is None: if isinstance(self.activation, torch.nn.ReLU) or self.activation == torch.relu: init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='relu') else: init.xavier_normal_(self.weight) else: for a, w in zip(self.activation, self.weight): if isinstance(a, torch.nn.ReLU) or a == torch.relu: init.kaiming_normal_(w, a=0, mode='fan_in', nonlinearity='relu') else: init.xavier_normal_(w) def forward(self, input: Tensor) -> Tensor: if self.hidden_dim is None: if self.activation is None: return F.linear(input, self.weight, self.bias) else: return self.activation(F.linear(input, self.weight, self.bias)) else: h = input if self.bias is None: for w, a in zip(self.weight, self.activation): h = a(F.linear(h, w, None)) else: for w, b, a in zip(self.weight, self.bias, self.activation): h = a(F.linear(h, w, b)) return h def extra_repr(self) -> str: if self.activation is None: return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features, self.bias is not None) elif isinstance(self.activation, vector): ret = 'in_features={}, out_features={}, bias={}, activation={}\n'.format(self.in_features, self.out_features, self.bias is not None, self.activation.map(lambda x: touch(lambda: x.__name__, str(x)))) ret += "{}".format(self.in_features) for d, a in zip(self.dims[1:], self.activation): ret += '->{}->{}'.format(d, touch(lambda: a.__name__, str(a))) return ret else: ret = 'in_features={}, out_features={}, bias={}, activation={}'.format(self.in_features, self.out_features, self.bias is not None, touch(lambda: self.activation.__name__, str(self.activation))) return ret def regulization_loss(self, p=2): if self.hidden_dim is None: if p == 2: return self.weight.square().sum() if p == 1: return self.weight.abs().sum() return (self.weight.abs() ** p).sum() else: reg = [] for w in self.weight: reg.append((w.abs() ** p).sum()) return sum(reg)
class group_relaxed_TF1Conv2d(Module): """Implementation of TF1 regularization for the feature maps of a convolutional layer""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, lamba=1., alpha=1., beta=4., weight_decay = 1., **kwargs): """ :param in_channels: Number of input channels :param out_channels: Number of output channels :param kernel_size: size of the kernel :param stride: stride for the convolution :param padding: padding for the convolution :param dilation: dilation factor for the convolution :param groups: how many groups we will assume in the convolution :param bias: whether we will use a bias :param lamba: strength of the TFL regularization """ super(group_relaxed_TF1Conv2d, self).__init__() self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.lamba = lamba self.alpha = alpha self.beta = beta self.lamba1 = self.lamba/self.beta self.weight_decay = weight_decay self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size) self.u = self.u.to('cuda') if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() self.input_shape = None print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode='fan_in') if self.bias is not None: self.bias.data.normal_(0,1e-2) def phi(self,x): phi_x = torch.acos(1-27*(self.lamba1*self.alpha*(self.alpha+1))/(2*(self.alpha+x.abs())**3)) return phi_x def g(self,x): g_x = x.sign()*(2/3*(self.alpha + x.abs())*torch.cos(self.phi(x)/3)-2*self.alpha/3+x.abs()/3) return g_x def constrain_parameters(self, thres_std=1.): #self.weight.data = F.normalize(self.weight.data, p=2, dim=1) #print(torch.sum(self.weight.pow(2))) if self.lamba1 <= (self.alpha**2)/(2*(self.alpha+1)): t = self.lamba1*(self.alpha+1)/(self.alpha) else: t = np.sqrt(2*self.lamba1*(self.alpha+1))-self.alpha/2 self.u.data = self.weight.data.clone() self.u.data[self.u.data.abs() <=t] = 0 g_result = self.g(self.u) self.u.data[self.u.data.abs() > t] = g_result[self.u.data.abs() > t] def grow_beta(self, growth_factor): self.beta = self.beta*growth_factor self.lamba1 = self.lamba/self.beta def _reg_w(self, **kwargs): logpw = -self.beta*torch.sum(0.5*self.weight.add(-self.u).pow(2))-self.lamba*np.sqrt(self.in_channels*self.kernel_size[0]*self.kernel_size[1])*torch.sum(torch.pow(torch.sum(self.weight.pow(2),3).sum(2).sum(1),0.5)) logpb = 0 if self.bias is not None: logpb = - torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw+logpb def regularization(self): return self._reg_w() def count_zero_u(self): total = np.prod(self.u.size()) zero = total - self.u.nonzero().size(0) return zero def count_zero_w(self): return torch.sum((self.weight.abs()<1e-5).int()).item() def count_active_neuron(self): return torch.sum((torch.sum(self.weight.abs(),3).sum(2).sum(1)/(self.in_channels*self.kernel_size[0]*self.kernel_size[1]))>1e-5).item() def count_total_neuron(self): return self.out_channels def count_weight(self): return np.prod(self.u.size()) def count_expected_flops_and_l0(self): #ppos = self.out_channels ppos = 
torch.sum(torch.sum(self.weight.abs(),3).sum(2).sum(1)>0.001).item() n = self.kernel_size[0]*self.kernel_size[1]*self.in_channels flops_per_instance = n+(n-1) num_instances_per_filter = ((self.input_shape[1] -self.kernel_size[0]+2*self.padding[0])/self.stride[0]) + 1 num_instances_per_filter *=((self.input_shape[2] - self.kernel_size[1]+2*self.padding[1])/self.stride[1]) + 1 flops_per_filter = num_instances_per_filter * flops_per_instance expected_flops = flops_per_filter*ppos expected_l0 = n*ppos if self.bias is not None: expected_flops += num_instances_per_filter*ppos expected_l0 += ppos return expected_flops, expected_l0 def forward(self, input_): if self.input_shape is None: self.input_shape = input_.size() output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return output def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} ' ', stride={stride}') if self.padding != (0,) * len(self.padding): s += ', padding={padding}' if self.dilation != (1,) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0,) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if self.bias is None: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
class FilterStripe(nn.Conv2d):  # convolution layer + Filter Skeleton (FS)
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels, out_channels, kernel_size, stride, kernel_size // 2, groups=1, bias=False)
        self.BrokenTarget = None
        # Initialize the Filter Skeleton: one entry per (filter, kernel row, kernel column) stripe.
        self.FilterSkeleton = Parameter(torch.ones(self.out_channels, self.kernel_size[0], self.kernel_size[1]), requires_grad=True)

    def forward(self, x):  # x: [N, channels, height, width]
        if self.BrokenTarget is not None:
            # out: [N, out_channels, height, width]; ceil() rounds the strided spatial size up.
            out = torch.zeros(x.shape[0], self.FilterSkeleton.shape[0], int(np.ceil(x.shape[2] / self.stride[0])), int(np.ceil(x.shape[3] / self.stride[1])))
            if x.is_cuda:
                out = out.cuda()
            x = F.conv2d(x, self.weight)  # 1x1 convolution over the surviving stripes
            l, h = 0, 0
            for i in range(self.BrokenTarget.shape[0]):
                for j in range(self.BrokenTarget.shape[1]):
                    h += self.FilterSkeleton[:, i, j].sum().item()  # number of filters keeping this stripe
                    # Shift the stripe responses into place and accumulate them on the matching output channels.
                    out[:, self.FilterSkeleton[:, i, j]] += self.shift(x[:, l:h], i, j)[:, :, ::self.stride[0], ::self.stride[1]]
                    l += self.FilterSkeleton[:, i, j].sum().item()
            return out
        else:
            # unsqueeze(1) adds a broadcast dimension so the skeleton scales every input channel of a filter.
            return F.conv2d(x, self.weight * self.FilterSkeleton.unsqueeze(1), stride=self.stride, padding=self.padding, groups=self.groups)

    def prune_in(self, in_mask=None):  # in_mask selects the input channels to keep
        # weight shape: [out_channels, in_channels, k, k]; keep only the selected input channels.
        self.weight = Parameter(self.weight[:, in_mask])
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):  # threshold is the pruning cut-off
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0  # filters with at least one surviving stripe
        if out_mask.sum() == 0:
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])  # mask the kernels
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask], requires_grad=True)  # mask the skeleton
        self.out_channels = out_mask.sum().item()  # update the output channel count
        return out_mask

    def _break(self, threshold):
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))  # fold the skeleton into the kernels
        self.FilterSkeleton = Parameter((self.FilterSkeleton.abs() > threshold), requires_grad=False)  # entries above threshold become True
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        # permute() reorders the dimensions so each stripe becomes one 1x1 filter, then keep only surviving stripes.
        self.weight = Parameter(self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1)[self.FilterSkeleton.permute(1, 2, 0).reshape(-1)])

    def update_skeleton(self, sr, threshold):
        # Add the subgradient of the L1 penalty to the skeleton gradient, then zero entries below the threshold.
        self.FilterSkeleton.grad.data.add_(sr * torch.sign(self.FilterSkeleton.data))
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)
        self.FilterSkeleton.grad.data.mul_(mask)
        out_mask = mask.sum(dim=(1, 2)) != 0
        return out_mask

    def shift(self, x, i, j):
        return F.pad(x, (self.BrokenTarget.shape[0] // 2 - j, j - self.BrokenTarget.shape[0] // 2, self.BrokenTarget.shape[0] // 2 - i, i - self.BrokenTarget.shape[1] // 2), 'constant', 0)

    def extra_repr(self):
        s = ('{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)
class group_relaxed_L1L2Dense(Module):
    """Implementation of TFL regularization for the input units of a fully connected layer"""

    def __init__(self, in_features, out_features, bias=True, lamba=1., alpha=1., beta=4., weight_decay=1., **kwargs):
        """
        :param in_features: input dimensionality
        :param out_features: output dimensionality
        :param bias: whether we use bias
        :param lamba: strength of the TF1 regularization
        """
        super(group_relaxed_L1L2Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.u = torch.rand(in_features, out_features)
        self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) + self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0)
            self.u[self.u != 0] = self.weight[self.u != 0].sign() * self.alpha * self.lamba1 / (n**(1 / 2))
        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            max_idx = np.unravel_index(torch.argmax(self.u.cpu(), None), self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w + (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(self.out_features) * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', lambda: ' \
            + str(self.lamba) + ')'
class SubnetLinear(nn.Linear): # self.k is the % of weights remaining, a real number in [0,1] # self.popup_scores is a Parameter which has the same shape as self.weight # Gradients to self.weight, self.bias have been turned off. def __init__(self, in_features, out_features, bias=True): super(SubnetLinear, self).__init__(in_features, out_features, bias=True) # Weight pruning # self.popup_scores = Parameter(torch.Tensor(self.weight.shape)) # Channel Finetuning or Resume Pruning # self.popup_scores = Parameter(torch.Tensor(torch.Size([1,self.weight.shape[1]]))) # Channel Pruning self.popup_scores = Parameter( torch.Tensor(torch.Size([self.weight.shape[0], 1]))) nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5)) self.weight.requires_grad = False self.bias.requires_grad = False self.w = 0 # self.register_buffer('w', None) def set_prune_rate(self, k): self.k = k def forward(self, x): """ Unstructured Comparison remaining_weights = int(self.k * len(self.weight.flatten())) idx_same_top_weights_scores = list( set(torch.topk(self.weight.abs().flatten(), remaining_weights).indices.tolist()).intersection( set(torch.topk(self.popup_scores.abs().flatten(), remaining_weights).indices.tolist()))) num_remaining_weights = len(idx_same_top_weights_scores) print( f"SubnetLinear: Number of same indices for scores and weights that are left after pruning: " f"{num_remaining_weights}. These are {float(num_remaining_weights / remaining_weights)} percent of the " f"weights kept.") """ """ Structured Comparison remaining_filters = int(self.k * self.weight.shape[0]) idx_same_top_weights_scores = list(set( torch.topk(torch.linalg.norm(self.weight.abs().reshape(self.weight.shape[0], -1), 1, dim=1), remaining_filters).indices.tolist()).intersection( torch.topk(torch.linalg.norm(self.popup_scores.abs().reshape(self.popup_scores.shape[0], -1), 1, dim=1), remaining_filters).indices.tolist())) num_remaining_filters = len(idx_same_top_weights_scores) print( f"SubnetLinear: Number of same indices for filters that are left after pruning using scores or weights : " f"{num_remaining_filters}. These are {float(num_remaining_filters / remaining_filters)} percent of the " f"filters kept.") """ """ Channel Prune VGG16 global linear_nr if linear_nr == 3: linear_nr = 1 else: linear_nr += 1 # Get the subnetwork by sorting the scores. mask_linear_50 = [0.10107422, 0.1015625, 0.1015625] mask_linear_10 = [0.016601562, 0.015625, 0.015625] k = mask_linear_10[linear_nr-1] adj = GetSubnet.apply(self.popup_scores.abs(), self.k) """ # Fixed mask WRN Channel Prune # adj = GetSubnet.apply(self.popup_scores.abs(), 0.44140625) adj = GetSubnet.apply(self.popup_scores.abs(), self.k) # Use only the subnetwork in the forward pass. self.w = self.weight * adj x = F.linear(x, self.w, self.bias) return x
class group_relaxed_L1L2Conv2d(Module):
    """Implementation of the group relaxed L1-L2 regularization for the feature maps of a convolutional layer"""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, lamba=1.,
                 alpha=1., beta=4., weight_decay=1., **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: size of the kernel
        :param stride: stride for the convolution
        :param padding: padding for the convolution
        :param dilation: dilation factor for the convolution
        :param groups: how many groups we will assume in the convolution
        :param bias: whether we will use a bias
        :param lamba: strength of the relaxed L1-L2 regularization
        """
        super(group_relaxed_L1L2Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        # Auxiliary (relaxation) variable u, kept on the GPU.
        self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size)
        self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) + self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0)
            self.u[self.u != 0] = self.weight[self.u != 0].sign() * self.alpha * self.lamba1 / (n ** 0.5)
        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            max_idx = np.unravel_index(torch.argmax(self.u.cpu(), None), self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w + (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2)) \
            - self.lamba * np.sqrt(self.in_channels * self.kernel_size[0] * self.kernel_size[1]) \
            * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) /
                          (self.in_channels * self.kernel_size[0] * self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_weight(self):
        return np.prod(self.u.size())

    def count_expected_flops_and_l0(self):
        # ppos = self.out_channels
        ppos = torch.sum(torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)
        num_instances_per_filter = ((self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1
        num_instances_per_filter *= ((self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1
        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos
        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
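# A hedged usage sketch for the layer above (layer sizes, learning rate, and the
# grow_beta schedule are illustrative, not taken from the original training code).
# It assumes a CUDA device, since __init__ pins u to the GPU, and that helpers such
# as pair and Softshrink are importable in the defining module. Note that
# regularization() returns a log-prior, so it is subtracted from the task loss.
import torch
import torch.nn.functional as F

conv = group_relaxed_L1L2Conv2d(3, 16, kernel_size=3, padding=1, lamba=1e-3).to('cuda')
opt = torch.optim.SGD(conv.parameters(), lr=0.1)

x = torch.randn(8, 3, 32, 32, device='cuda')
target = torch.randint(0, 16, (8,), device='cuda')

for step in range(10):
    out = conv(x).mean(dim=[2, 3])                      # [8, 16] pooled logits, purely illustrative
    loss = F.cross_entropy(out, target) - conv.regularization()
    opt.zero_grad()
    loss.backward()
    opt.step()
    conv.constrain_parameters()                         # update the auxiliary variable u
    if step % 5 == 4:
        conv.grow_beta(1.1)                             # anneal the coupling strength (assumed schedule)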
class lq_conv2d_orig(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False,
                 padding_mode='zeros', is_qt=False, tr_gamma=True, lq=False,
                 block_num=-1, layer_num=-1, index=[], fwlq=False):
        super(lq_conv2d_orig, self).__init__(in_channels, out_channels, kernel_size,
                                             stride, padding, dilation, groups, bias,
                                             padding_mode)
        self.block_num = block_num
        self.layer_num = layer_num
        self.index = index
        self.w_shape = self.weight.shape
        self.is_qt = is_qt
        self.lq = lq
        self.fwlq = fwlq
        # if groups != 1:
        #     self.lq = False
        if lq:
            if fwlq:
                print("filter-wise learning to quantize")
                self.cw = Parameter(torch.ones([out_channels, 1]))
                self.dw = Parameter(torch.ones([out_channels, 1]))
                self.gamma = Parameter(torch.ones([out_channels, 1])) if tr_gamma else 1
            else:
                self.cw = Parameter(torch.Tensor([1]))
                self.dw = Parameter(torch.Tensor([1]))
                self.gamma = Parameter(torch.Tensor([1])) if tr_gamma else 1
            self.cx = Parameter(torch.Tensor([2]))
            self.dx = Parameter(torch.Tensor([2]))
            self.tr_gamma = tr_gamma

    def set_bit_width(self, w_bit, x_bit, initskip):
        self.w_bit = w_bit
        self.x_bit = x_bit
        if isinstance(x_bit, list):
            self.qx = [2**(bit) - 1 for bit in x_bit]
            self.theta_x = Parameter(torch.ones(len(x_bit)) / len(x_bit))
        else:
            self.qx = 2**(x_bit) - 1
        if isinstance(w_bit, list):
            self.qw = [2**(bit - 1) - 1 for bit in w_bit]
            self.theta_w = Parameter(torch.ones(len(w_bit)) / len(w_bit))
        else:
            self.qw = 2**(w_bit - 1) - 1

        # Read filter-wise bitwidth index
        if self.index != []:
            self.qw = torch.ones((self.w_shape[0], 1))
            bit_max = 9
            for i in range(bit_max):
                if len(self.index[self.block_num][self.layer_num][i]) == 0:
                    continue
                else:
                    idx = self.index[self.block_num][self.layer_num][i]
                    self.qw[idx] = 2**(i + 1) - 1
            self.qw = self.qw.cuda()

        # Initialize c, d from the weight statistics
        if self.lq and not initskip:
            with torch.no_grad():
                if self.fwlq:
                    self.cw *= self.weight.abs().mean()  # (dim=[1,2,3]).view((-1,1))
                    self.dw *= self.weight.std()         # (dim=[1,2,3]).view((-1,1))
                else:
                    self.cw *= self.weight.abs().mean()
                    self.dw *= self.weight.std()

    def bitops_count(self, soft_mask_w=None, soft_mask_x=None):
        x_shape = torch.Tensor([self.x_shape])
        w_shape = torch.Tensor([self.weight.shape])
        flops = x_shape.prod()
        flops *= w_shape.prod()
        # case 1: soft mask w, one value x
        if soft_mask_x is None and soft_mask_w is not None:
            bitops = torch.Tensor(self.w_bit).cuda() * soft_mask_w * self.x_bit
            bitops *= flops
            bitops = bitops.sum()
        # case 0: 32-bit w and x
        elif not (soft_mask_x or soft_mask_w):
            bitops = torch.Tensor([32 * 32]).cuda()
            bitops *= flops
        return bitops

    def forward(self, input):
        self.x_shape = input.shape[2:]
        soft_mask_w = None
        if self.lq:
            w_abs = self.weight.abs()
            w_sign = self.weight.sign()
            w_abs = w_abs.view(self.w_shape[0], -1)
            w_sign = w_sign.view(self.w_shape[0], -1)
            eps = 1e-7
            _dw = self.dw.abs() + eps
            _dx = self.dx

            # Transformer_W: compress weights below the threshold d with exponent gamma.
            w_mask1 = (w_abs <= _dw).type(torch.float).detach()
            w_mask2 = (w_abs > _dw).type(torch.float).detach()
            w_cal = w_abs / _dw
            nan_detect(w_cal)
            nan_detect(w_cal.pow(self.gamma))
            w_hat = (w_mask2 * w_sign) + (w_mask1 * (w_cal).pow(self.gamma) * w_sign)
            nan_detect(w_hat)

            # Discretizer_W
            if isinstance(self.qw, list):
                # 1. learning bitwidth
                w_bar_list = []
                for qw in self.qw:
                    w_bar = Round.apply(w_hat * qw) / qw
                    nan_detect(w_bar)
                    w_bar_list.append(w_bar)
                soft_mask_w = nn.functional.gumbel_softmax(self.theta_w, tau=1, hard=False)
                w_bar = sum(w * theta for w, theta in zip(w_bar_list, soft_mask_w))
                w_bar = w_bar.view(self.w_shape)
            else:
                # 2. fixed bitwidth
                w_bar = Round.apply(w_hat * self.qw) / self.qw
                nan_detect(w_bar)
                w_bar = w_bar.view(self.w_shape)
            nan_detect(w_bar)

            # Transformer_X
            x_mask1 = (input <= self.dx).type(torch.float).detach()
            x_mask2 = (input > self.dx).type(torch.float).detach()
            x_cal = input / self.dx
            nan_detect(x_cal)
            x_hat = x_mask1 * x_cal + x_mask2
            nan_detect(x_hat)

            # Discretizer_X
            if isinstance(self.qx, list):
                # 1. learning bitwidth
                x_bar_list = []
                for qx in self.qx:
                    x_bar = Round.apply(x_hat * qx) / qx
                    nan_detect(x_bar)
                    x_bar_list.append(x_bar)
                soft_mask_x = nn.functional.gumbel_softmax(self.theta_x, tau=1, hard=False)
                x_bar = sum(x * theta for x, theta in zip(x_bar_list, soft_mask_x))
            else:
                # 2. fixed bitwidth
                x_bar = Round.apply(x_hat * self.qx) / self.qx
                nan_detect(x_bar)

            y = F.conv2d(x_bar, w_bar, self.bias, self.stride, self.padding,
                         self.dilation, self.groups)
        elif self.is_qt:
            if isinstance(self.qw, list):
                w_list = []
                for qw in self.qw:
                    w_list.append(quantize(self.weight, num_bits=qw))
                soft_mask_w = nn.functional.gumbel_softmax(self.theta_w, tau=1, hard=False)
                w = sum(w_ * theta for w_, theta in zip(w_list, soft_mask_w))
                # w = w_bar.view(self.w_shape)
            else:
                w = quantize(self.weight, num_bits=self.w_bit, block_num=self.block_num,
                             layer_num=self.layer_num, multi=True, index=self.index)
            x = quantize(input, num_bits=self.x_bit, is_act=True)
            y = F.conv2d(x, w, self.bias, self.stride, self.padding,
                         self.dilation, self.groups)
        else:
            y = F.conv2d(input, self.weight, self.bias, self.stride, self.padding,
                         self.dilation, self.groups)
        flops = self.bitops_count(soft_mask_w=soft_mask_w, soft_mask_x=None)
        return y, flops
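# Round, nan_detect, and quantize are used above but not defined in this file. As an
# assumption about their behaviour (not the original definitions), Round can be
# sketched as a straight-through rounding op and nan_detect as a debugging assertion:
import torch
from torch import autograd


class Round(autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Snap to the nearest integer level in the forward pass.
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: identity gradient through the rounding step.
        return grad_output


def nan_detect(t):
    # Fail fast if a quantization step produced NaNs.
    assert not torch.isnan(t).any(), "NaN detected in quantization path"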
class AttentionReadout(Readout): def __init__( self, in_shape: Tuple[int, int, int], outdims: int, bias: bool, init_noise: float = 1e-3, attention_kernel: int = 1, attention_layers: int = 1, mean_activity: Optional[Mapping[str, float]] = None, feature_reg_weight: float = 1.0, gamma_readout: Optional[ float] = None, # deprecated, use feature_reg_weight instead **kwargs: Any, ) -> None: super().__init__() self.in_shape = in_shape self.outdims = outdims self.feature_reg_weight = self.resolve_deprecated_gamma_readout( feature_reg_weight, gamma_readout) # type: ignore[no-untyped-call] self.mean_activity = mean_activity c, w, h = in_shape self.features = Parameter(torch.Tensor(self.outdims, c)) attention = Sequential() for i in range(attention_layers - 1): attention.add_module( f"conv{i}", Conv2d(c, c, attention_kernel, padding=attention_kernel > 1), ) attention.add_module( f"norm{i}", BatchNorm2d(c)) # type: ignore[no-untyped-call] attention.add_module(f"nonlin{i}", ELU()) else: attention.add_module( f"conv{attention_layers}", Conv2d(c, outdims, attention_kernel, padding=attention_kernel > 1), ) self.attention = attention self.init_noise = init_noise if bias: bias_param = Parameter(torch.Tensor(self.outdims)) self.register_parameter("bias", bias_param) else: self.register_parameter("bias", None) self.initialize(mean_activity) @staticmethod def init_conv(m: Module) -> None: if isinstance(m, Conv2d): init.xavier_normal_(m.weight.data) if m.bias is not None: m.bias.data.fill_(0) def initialize_attention(self) -> None: self.apply(self.init_conv) def initialize( self, mean_activity: Optional[Mapping[str, float]] = None ) -> None: # type: ignore[override] if mean_activity is None: mean_activity = self.mean_activity self.features.data.normal_(0, self.init_noise) if self.bias is not None: self.initialize_bias( mean_activity=mean_activity) # type: ignore[no-untyped-call] self.initialize_attention() def feature_l1(self, reduction: Literal["sum", "mean", None] = "sum", average: Optional[bool] = None) -> torch.Tensor: return self.apply_reduction( self.features.abs(), reduction=reduction, average=average) # type: ignore[no-untyped-call,no-any-return] def regularizer(self, reduction: Literal["sum", "mean", None] = "sum", average: Optional[bool] = None) -> torch.Tensor: return self.feature_l1( reduction=reduction, average=average ) * self.feature_reg_weight # type: ignore[no-any-return] def forward(self, x: torch.Tensor, shift: Optional[Any] = None) -> torch.Tensor: attention = self.attention(x) b, c, w, h = attention.shape attention = F.softmax(attention.view(b, c, -1), dim=-1).view(b, c, w, h) y: torch.Tensor = torch.einsum("bnwh,bcwh->bcn", attention, x) # type: ignore[attr-defined] y = torch.einsum("bcn,nc->bn", y, self.features) # type: ignore[attr-defined] if self.bias is not None: y = y + self.bias return y def __repr__(self) -> str: return self.__class__.__name__ + " (" + "{} x {} x {}".format( *self.in_shape) + " -> " + str(self.outdims) + ")"
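# A shape walk-through of the attention readout above, using dummy tensors (the
# Readout base class is not shown here, so forward() is mimicked directly; all sizes
# are illustrative).
import torch
import torch.nn.functional as F

b, c, w, h, outdims = 2, 8, 6, 6, 5
x = torch.randn(b, c, w, h)                              # core output
attention = torch.randn(b, outdims, w, h)                # output of self.attention(x)
attention = F.softmax(attention.view(b, outdims, -1), dim=-1).view(b, outdims, w, h)
features = torch.randn(outdims, c)                       # self.features

y = torch.einsum("bnwh,bcwh->bcn", attention, x)         # spatial pooling per neuron: [b, c, outdims]
y = torch.einsum("bcn,nc->bn", y, features)              # channel mixing per neuron: [b, outdims]
assert y.shape == (b, outdims)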
class group_relaxed_SCAD_Dense(Module):
    """Implementation of the group relaxed SCAD regularization for the input units of a fully connected layer"""

    def __init__(self, in_features, out_features, bias=True, lamba=1.,
                 alpha=3.7, beta=4.0, weight_decay=1., **kwargs):
        """
        :param in_features: input dimensionality
        :param out_features: output dimensionality
        :param bias: whether we use bias
        :param lamba: strength of the SCAD regularization
        """
        super(group_relaxed_SCAD_Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.u = torch.rand(in_features, out_features)
        self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        self.u = self.weight.clone()
        s = Softshrink(self.lamba1)
        # shrinkage on values with absolute value less than 2*lamba1
        shrink_value = s(self.weight.data)
        small = self.weight.abs() <= 2 * self.lamba1
        self.u[small] = shrink_value[small]
        # modify values whose absolute values are between 2*lamba1 and alpha*lamba1
        modify_weight = self.weight.data
        modify_weight = ((self.alpha - 1) * modify_weight - modify_weight.sign() * (self.alpha * self.lamba1)) / (self.alpha - 2)
        mid = (self.weight.abs() > 2 * self.lamba1) & (self.weight.abs() <= self.alpha * self.lamba1)
        self.u[mid] = modify_weight[mid]

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2)) \
            - self.lamba * np.sqrt(self.out_features) * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', lambda: ' \
            + str(self.lamba) + ')'
class sparse_group_lasso_Conv2d(Module):
    """Implementation of the sparse group lasso regularization for the feature maps of a convolutional layer"""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, lamba=1.,
                 weight_decay=1., **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: size of the kernel
        :param stride: stride for the convolution
        :param padding: padding for the convolution
        :param dilation: dilation factor for the convolution
        :param groups: how many groups we will assume in the convolution
        :param bias: whether we will use a bias
        :param lamba: strength of the sparse group lasso regularization
        """
        super(sparse_group_lasso_Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.weight_decay = weight_decay
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, thres_std=1.):
        pass

    def _reg_w(self, **kwargs):
        logpw = -self.lamba * np.sqrt(self.in_channels * self.kernel_size[0] * self.kernel_size[1]) \
            * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5)) \
            - torch.sum(self.lamba * self.weight.abs())
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.weight.size())

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) /
                          (self.in_channels * self.kernel_size[0] * self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_expected_flops_and_l0(self):
        # ppos = self.out_channels
        ppos = torch.sum(torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)
        num_instances_per_filter = ((self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1
        num_instances_per_filter *= ((self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1
        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos
        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
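# A hedged standalone restatement of the penalty that _reg_w above returns (as a
# log-prior, i.e. the negative of this penalty); the group is the fan-in of one
# output filter. The function name and signature are illustrative.
import numpy as np
import torch


def sparse_group_lasso_penalty(weight, lamba, in_channels, kh, kw):
    # Group lasso over output filters plus an elementwise L1 term.
    group_scale = lamba * np.sqrt(in_channels * kh * kw)
    group_term = group_scale * weight.pow(2).sum(dim=[1, 2, 3]).sqrt().sum()
    l1_term = lamba * weight.abs().sum()
    return group_term + l1_term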
class CGES_Dense(Module): """Implementation of TFL regularization for the input units of a fully connected layer""" def __init__(self, in_features, out_features, bias=True, lamba=1., weight_decay=1., mu=0.5, **kwargs): """ :param in_features: input dimensionality :param out_features: output dimensionality :param bias: whether we use bias :param lamba: strength of the TF1 regularization """ super(CGES_Dense, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(in_features, out_features)) if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.lamba = lamba self.weight_decay = weight_decay self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.mu = mu self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode='fan_out') if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, **kwargs): pass def _reg_w(self, **kwargs): logpw = -(1 - self.mu) * self.lamba * torch.sum( torch.pow(torch.sum(self.weight.pow(2), 1), 0.5)) - self.mu * self.lamba / 2 * torch.sum( torch.sum(self.weight.abs(), 1)**2) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw + logpb def set_mu(self, mu): self.mu = mu def regularization(self): return self._reg_w() def count_zero_w(self): return torch.sum((self.weight.abs() < 1e-5).int()).item() def count_weight(self): return np.prod(self.weight.size()) def count_active_neuron(self): return torch.sum( torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item() def count_total_neuron(self): return self.in_features def count_expected_flops_and_l0(self): ppos = torch.sum(self.weight.abs() > 0.000001).item() expected_flops = (2 * ppos - 1) * self.out_features expected_l0 = ppos * self.out_features if self.bias is not None: expected_flops += self.out_features expected_l0 += self.out_features return expected_flops, expected_l0 def forward(self, input): output = input.mm(self.weight) if self.bias is not None: output.add_(self.bias.view(1, self.out_features).expand_as(output)) return output def __repr__(self): return self.__class__.__name__+' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', lambda: ' \ + str(self.lamba) + ')'
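# CGES interpolates between a group-sparsity term (mu = 0) and an exclusive-sparsity
# term (mu = 1). A hedged sketch of a per-epoch schedule using set_mu; the linear
# ramp and the layer sizes are assumptions, not taken from the original training
# script.
layer = CGES_Dense(784, 100, lamba=1e-4)
num_epochs = 50
for epoch in range(num_epochs):
    layer.set_mu(epoch / (num_epochs - 1))   # ramp mu from 0 to 1 over training
    # ... run one training epoch, adding -layer.regularization() to the loss ...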