# Imports required by the layers defined in this module.
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Softshrink, init
from torch.nn.modules import Module
from torch.nn.modules.utils import _pair as pair
from torch.nn.parameter import Parameter


class group_relaxed_L1L2Conv2d(Module):
    """Implementation of relaxed group L1-L2 regularization for the feature maps of a convolutional layer"""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, lamba=1.,
                 alpha=1., beta=4., weight_decay=1., **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: size of the kernel
        :param stride: stride for the convolution
        :param padding: padding for the convolution
        :param dilation: dilation factor for the convolution
        :param groups: how many groups we will assume in the convolution
        :param bias: whether we will use a bias
        :param lamba: strength of the relaxed L1-L2 regularization
        """
        super(group_relaxed_L1L2Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        # auxiliary variable u used by the relaxed splitting scheme
        self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size)
        if torch.cuda.is_available():
            self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        # closed-form (proximal) update of the auxiliary variable u for the L1-L2 penalty
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.norm(p=2) + self.alpha * self.lamba1) / z.norm(p=2)
        elif norm_w == self.lamba1:
            self.u = self.weight.data.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0)
            self.u[self.u != 0] = self.weight.data.sign()[self.u != 0] * self.alpha * self.lamba1 / (n ** (1 / 2))
        elif (1 - self.alpha) * self.lamba1 < norm_w < self.lamba1:
            self.u = self.weight.data.clone()
            max_idx = np.unravel_index(torch.argmax(self.u).cpu().item(), self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w + (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.data.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = (-self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2))
                 - self.lamba * np.sqrt(self.in_channels * self.kernel_size[0] * self.kernel_size[1])
                 * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1)
                          / (self.in_channels * self.kernel_size[0] * self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_weight(self):
        return np.prod(self.u.size())

    def count_expected_flops_and_l0(self):
        # ppos = self.out_channels
        ppos = torch.sum(torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)

        num_instances_per_filter = ((self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1
        num_instances_per_filter *= ((self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos

        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
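
# --------------------------------------------------------------------------- #
# Illustrative usage (not part of the original layer code). A minimal sketch,
# assuming the usual pattern for these relaxed regularizers: regularization()
# returns a log-prior term (a negative penalty), so it is subtracted from the
# task loss; constrain_parameters() refreshes the auxiliary variable u and
# grow_beta() tightens the w-u coupling between epochs. The toy data, loss,
# learning rate, and growth factor below are hypothetical placeholders.
# --------------------------------------------------------------------------- #
def _example_conv_usage():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    layer = group_relaxed_L1L2Conv2d(3, 8, kernel_size=3, padding=1,
                                     lamba=0.1, alpha=1.0, beta=4.0).to(device)
    optimizer = torch.optim.SGD(layer.parameters(), lr=1e-2)
    x = torch.randn(2, 3, 16, 16, device=device)
    target = torch.randn(2, 8, 16, 16, device=device)
    for _ in range(2):  # a couple of toy iterations
        optimizer.zero_grad()
        # task loss minus the layer's log-prior term
        loss = F.mse_loss(layer(x), target) - layer.regularization()
        loss.backward()
        optimizer.step()
        layer.constrain_parameters()  # proximal update of u
    layer.grow_beta(1.1)              # assumed per-epoch growth factor
    print(layer.count_active_neuron(), "of", layer.count_total_neuron(),
          "filters still active")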
class group_relaxed_L1L2Dense(Module):
    """Implementation of relaxed group L1-L2 regularization for the input units of a fully connected layer"""

    def __init__(self, in_features, out_features, bias=True, lamba=1.,
                 alpha=1., beta=4., weight_decay=1., **kwargs):
        """
        :param in_features: input dimensionality
        :param out_features: output dimensionality
        :param bias: whether we use a bias
        :param lamba: strength of the relaxed L1-L2 regularization
        """
        super(group_relaxed_L1L2Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        # auxiliary variable u used by the relaxed splitting scheme
        self.u = torch.rand(in_features, out_features)
        if torch.cuda.is_available():
            self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        # closed-form (proximal) update of the auxiliary variable u for the L1-L2 penalty
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.norm(p=2) + self.alpha * self.lamba1) / z.norm(p=2)
        elif norm_w == self.lamba1:
            self.u = self.weight.data.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0)
            self.u[self.u != 0] = self.weight.data.sign()[self.u != 0] * self.alpha * self.lamba1 / (n ** (1 / 2))
        elif (1 - self.alpha) * self.lamba1 < norm_w < self.lamba1:
            self.u = self.weight.data.clone()
            max_idx = np.unravel_index(torch.argmax(self.u).cpu().item(), self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w + (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.data.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = (-self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2))
                 - self.lamba * np.sqrt(self.out_features)
                 * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 1), 0.5)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', lambda: ' \
            + str(self.lamba) + ')'
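
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original code): how the per-layer hooks
# of both module types might be aggregated over a whole network. The helper
# names and the default growth factor are hypothetical; only the method names
# (regularization, constrain_parameters, grow_beta) come from the layers above.
# --------------------------------------------------------------------------- #
def _total_regularization(model):
    """Sum the log-prior terms of every relaxed L1-L2 layer in `model`."""
    reg = 0.
    for m in model.modules():
        if isinstance(m, (group_relaxed_L1L2Conv2d, group_relaxed_L1L2Dense)):
            reg = reg + m.regularization()
    return reg


def _epoch_end_update(model, growth_factor=1.1):
    """Refresh u and grow beta for every relaxed L1-L2 layer after an epoch."""
    for m in model.modules():
        if isinstance(m, (group_relaxed_L1L2Conv2d, group_relaxed_L1L2Dense)):
            m.constrain_parameters()
            m.grow_beta(growth_factor)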