class InnerProductKernel(Kernel):
    """Kernel whose Gram matrix is produced by ``inner_product.InnerProductKernel``.

    Besides the ``log_amp`` parameter used through the parent class, the
    kernel owns ``sigma_sqrt``: a vector of length ``ndim`` when
    ``diagonal`` is true, otherwise an ``ndim x ndim`` matrix.

    The flat parameter vector used by ``param_to_vec`` / ``vec_to_param`` /
    ``prior`` / ``out_of_bounds`` is laid out as ``[log_amp, sigma_sqrt...]``.
    """

    def __init__(self, ndim, input_map=lambda x: x, diagonal=True):
        """Allocate (uninitialised) parameters.

        Args:
            ndim: size of ``sigma_sqrt`` along each of its dimensions.
            input_map: callable applied to both inputs before the kernel
                function is evaluated (identity by default).
            diagonal: if true ``sigma_sqrt`` is a vector, else a square matrix.
        """
        super(InnerProductKernel, self).__init__(input_map)
        self.diag = diagonal
        if diagonal:
            self.sigma_sqrt = Parameter(torch.FloatTensor(ndim))
        else:
            self.sigma_sqrt = Parameter(torch.FloatTensor(ndim, ndim))

    def reset_parameters(self):
        """Reset the parent's parameters and draw ``sigma_sqrt`` from N(0, 1)."""
        super(InnerProductKernel, self).reset_parameters()
        self.sigma_sqrt.data.normal_()

    def out_of_bounds(self, vec=None):
        """Delegate the bound check of the amplitude entry to the parent class.

        Args:
            vec: optional flat parameter vector; when omitted the current
                ``log_amp`` is checked instead.

        Returns:
            Whatever the parent's ``out_of_bounds`` returns for the amplitude.
        """
        if vec is None:
            return super(InnerProductKernel, self).out_of_bounds(self.log_amp)
        # The original negated this result, contradicting the ``vec is None``
        # branch above (in-bounds vectors were reported as out of bounds).
        return super(InnerProductKernel, self).out_of_bounds(vec[:1])

    def n_params(self):
        """Return the parent's parameter count plus the size of ``sigma_sqrt``."""
        # ``sigma_chol`` was never defined on this class; the parameter is
        # called ``sigma_sqrt``.
        return super(InnerProductKernel, self).n_params() + self.sigma_sqrt.numel()

    def param_to_vec(self):
        """Return ``[log_amp, sigma_sqrt...]`` as one flat tensor."""
        # Flatten so the concatenation also works while ``sigma_sqrt`` is
        # still stored as an (ndim, ndim) matrix.
        return torch.cat([self.log_amp.data, self.sigma_sqrt.data.view(-1)])

    def vec_to_param(self, vec):
        """Load parameters from a flat vector laid out as in ``param_to_vec``.

        ``sigma_sqrt`` is stored as the raw slice ``vec[1:]``; ``forward``
        reshapes it back to a square matrix when the kernel is non-diagonal.
        """
        self.log_amp.data = vec[0:1]
        self.sigma_sqrt.data = vec[1:]

    def prior(self, vec):
        """Sum the parent's prior on ``vec[:1]`` and ``smp.normal`` on ``vec[1:]``.

        NOTE(review): ``smp.normal`` is defined outside this file; presumably it
        returns a (log-)density term -- confirm against that module.
        """
        return super(InnerProductKernel, self).prior(vec[:1]) + smp.normal(vec[1:])

    def forward(self, input1, input2=None):
        """Evaluate the Gram matrix between ``input1`` and ``input2``.

        When ``input2`` is omitted the kernel is evaluated on ``input1``
        against itself and a diagonal term ``1e-6 * exp(log_amp)`` is added
        to the result.
        """
        stabilizer = 0
        if input2 is None:
            input2 = input1
            jitter = 1e-6 * math.exp(self.log_amp.data[0])
            stabilizer = Variable(
                torch.diag(input1.data.new(input1.size(0)).fill_(jitter)))

        if self.diag:
            sigma_sqrt = self.sigma_sqrt
        else:
            # After ``vec_to_param`` the parameter may be flat; restore the
            # square shape before handing it to the kernel function.
            side = int(self.sigma_sqrt.numel() ** 0.5)
            sigma_sqrt = self.sigma_sqrt.view(side, -1)

        gram_mat = inner_product.InnerProductKernel.apply(
            self.input_map(input1), self.input_map(input2),
            self.log_amp, sigma_sqrt)
        return gram_mat + stabilizer

    def __repr__(self):
        # Use the constructor flag: ``sigma_sqrt.dim()`` is 1 for a
        # non-diagonal kernel too once ``vec_to_param`` stored a flat slice.
        return self.__class__.__name__ + ' (diag=' + ('True' if self.diag else 'False') + ')'
class SparseTensor(nn.Module):
    """Dense parameter tensor combined with a 0/1 mask selecting active entries.

    ``s_tensor`` holds the full (2-D or 4-D) weight tensor; the ``mask``
    buffer has the shape of the first ``sub_kernel_granularity`` dimensions
    of that tensor, so one mask entry may gate a whole trailing slice.
    ``forward`` returns ``mask * s_tensor`` (with the mask broadcast over the
    trailing dimensions for 4-D tensors).
    """

    def __init__(self, tensor_size, initial_sparsity, sub_kernel_granularity=4):
        """Create the tensor and a random mask.

        Args:
            tensor_size: shape of the full tensor (2 or 4 dimensions).
            initial_sparsity: fraction of mask entries that start at 0.
            sub_kernel_granularity: number of leading dimensions covered by
                the mask.
        """
        super(SparseTensor, self).__init__()
        self.s_tensor = Parameter(torch.Tensor(torch.Size(tensor_size)))
        self.initial_sparsity = initial_sparsity
        self.sub_kernel_granularity = sub_kernel_granularity

        assert self.s_tensor.dim() == 2 or self.s_tensor.dim() == 4, \
            "can only do 2D or 4D sparse tensors"

        self.register_buffer(
            'mask', torch.Tensor(*(tensor_size[:sub_kernel_granularity])))
        # Number of weights gated by a single mask entry.
        self.normalize_coeff = np.prod(
            tensor_size[sub_kernel_granularity:]).item()
        self.conv_tensor = self.s_tensor.dim() != 2

        # Switch on a random subset of mask entries.
        self.mask.zero_()
        flat_mask = self.mask.view(-1)
        indices = np.arange(flat_mask.size(0))
        np.random.shuffle(indices)
        n_active = int((1 - initial_sparsity) * flat_mask.size(0) + 0.1)
        flat_mask[indices[:n_active]] = 1

        self.grown_indices = None
        self.init_parameters()
        self.reinitialize_unused()
        # Sign snapshot consumed by ``prune_sign_change``.
        self.tensor_sign = torch.sign(self.s_tensor.data.view(-1))

    def _flat_importance(self):
        """Return one score per mask entry, flattened.

        For 4-D tensors masked at a coarser granularity this is the mean
        absolute weight of each masked group; otherwise it is the raw
        flattened weight tensor (callers take ``abs`` themselves).
        """
        if self.conv_tensor and self.sub_kernel_granularity < 4:
            trailing = tuple(range(self.sub_kernel_granularity, 4))
            group_sum = self.s_tensor.data.abs().sum(trailing)
            return group_sum.view(-1) / self.normalize_coeff
        return self.s_tensor.data.view(-1)

    def reinitialize_unused(self, reinitialize_unused_to_zero=True):
        """Overwrite the weights at masked-out positions.

        Args:
            reinitialize_unused_to_zero: if true write zeros; otherwise write
                fresh normal samples with std ``sqrt(2 / n)`` for 4-D tensors
                (``n = size(0) * size(2) * size(3)``) or
                ``1 / sqrt(size(1))`` for 2-D tensors.
        """
        unused_positions = self.mask < 0.5
        if reinitialize_unused_to_zero:
            self.s_tensor.data[unused_positions] = 0
            return

        if self.conv_tensor:
            n = (self.s_tensor.size(0) * self.s_tensor.size(2)
                 * self.s_tensor.size(3))
            stdv = math.sqrt(2. / n)
        else:
            stdv = 1. / math.sqrt(self.s_tensor.size(1))

        shape = self.s_tensor.data[unused_positions].size()
        fresh = torch.zeros(shape).normal_(0, stdv)
        self.s_tensor.data[unused_positions] = fresh.to(self.s_tensor.device)

    def init_parameters(self):
        """Fill ``s_tensor`` uniformly in ``[-stdv, stdv]``.

        ``stdv`` is ``1 / sqrt(prod(shape[1:]))``.
        """
        stdv = 1 / math.sqrt(np.prod(self.s_tensor.size()[1:]))
        self.s_tensor.data.uniform_(-stdv, stdv)

    def prune_sign_change(self, reinitialize_unused_to_zero=True,
                          enable_print=False):
        """Deactivate active entries whose sign flipped since the last snapshot.

        The snapshot (``tensor_sign``) is refreshed on every call.

        Returns:
            Flat indices of the entries that were deactivated.
        """
        # NOTE(review): the sign vector has one entry per weight while the
        # mask has one entry per masked group; these index spaces only
        # coincide when the mask has the same shape as the tensor -- confirm
        # this method is not used with a coarser sub_kernel_granularity.
        W_flat = self.s_tensor.data.view(-1)
        new_tensor_sign = torch.sign(W_flat)
        mask_flat = self.mask.view(-1)
        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        old_sign = self.tensor_sign[mask_indices].to(new_tensor_sign.device)
        flipped = (new_tensor_sign[mask_indices] * old_sign) < -0.5
        sign_change_indices = mask_indices[flipped.nonzero().view(-1)]

        mask_flat[sign_change_indices] = 0
        self.reinitialize_unused(reinitialize_unused_to_zero)

        cutoff = sign_change_indices.numel()
        if enable_print:
            print('pruned {} connections'.format(cutoff))
        if self.grown_indices is not None and enable_print:
            overlap = np.intersect1d(sign_change_indices.cpu().numpy(),
                                     self.grown_indices.cpu().numpy())
            print('pruned {} ({} %) just grown weights'.format(
                overlap.size,
                overlap.size * 100.0 / self.grown_indices.size(0)
                if self.grown_indices.size(0) > 0 else 0.0))

        self.tensor_sign = new_tensor_sign
        return sign_change_indices

    def prune_small_connections(self, prune_fraction,
                                reinitialize_unused_to_zero=True):
        """Deactivate the smallest-magnitude active entries.

        ``int(prune_fraction * n_active) + 1`` entries are removed (capped by
        the number of active entries).

        Returns:
            Flat mask indices of the entries that were deactivated.
        """
        W_flat = self._flat_importance()
        mask_flat = self.mask.view(-1)
        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        W_masked = W_flat[mask_indices]
        sorted_W_indices = torch.sort(torch.abs(W_masked))[1]
        cutoff = int(prune_fraction * W_masked.numel()) + 1

        pruned = mask_indices[sorted_W_indices[:cutoff]]
        mask_flat[pruned] = 0
        self.reinitialize_unused(reinitialize_unused_to_zero)
        return pruned

    def prune_threshold(self, threshold, reinitialize_unused_to_zero=True):
        """Deactivate active entries whose magnitude is below ``threshold``.

        If every active entry would be removed, the first one is kept.

        Returns:
            Flat mask indices of the entries that were deactivated.
        """
        W_flat = self._flat_importance()
        mask_flat = self.mask.view(-1)
        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        W_masked = W_flat[mask_indices]
        prune_indices = (W_masked.abs() < threshold).nonzero().view(-1)
        if mask_indices.size(0) == prune_indices.size(0):
            print('removing all. keeping one')
            prune_indices = prune_indices[1:]

        pruned = mask_indices[prune_indices]
        mask_flat[pruned] = 0
        self.reinitialize_unused(reinitialize_unused_to_zero)
        return pruned

    def grow_random(self, grow_fraction, pruned_indices=None,
                    enable_print=False, n_to_add=None):
        """Activate randomly chosen inactive mask entries.

        Args:
            grow_fraction: fraction of inactive entries to activate; only
                used when neither ``pruned_indices`` nor ``n_to_add`` is given.
            pruned_indices: flat indices to exclude from the candidates; when
                given, the number of entries to grow defaults to its length.
            enable_print: print the number of grown entries.
            n_to_add: explicit number of entries to grow (overrides the above).

        Returns:
            The number of entries that were activated.
        """
        mask_flat = self.mask.view(-1)
        mask_zero_indices = torch.nonzero(mask_flat < 0.5).view(-1)

        if pruned_indices is not None:
            cutoff = pruned_indices.size(0)
            candidates = np.setdiff1d(mask_zero_indices.cpu().numpy(),
                                      pruned_indices.cpu().numpy())
            # Stay in integer dtype: going through ``torch.Tensor`` (float32)
            # silently corrupts indices above 2**24.
            mask_zero_indices = torch.from_numpy(candidates).long().to(
                mask_zero_indices.device)
        else:
            cutoff = int(grow_fraction * mask_zero_indices.size(0))

        if n_to_add is not None:
            cutoff = n_to_add

        if mask_zero_indices.numel() < cutoff:
            print('******no place to grow {} connections, growing {} instead'.
                  format(cutoff, mask_zero_indices.numel()))
            cutoff = mask_zero_indices.numel()

        if enable_print:
            print('grown {} connections'.format(cutoff))

        shuffled = mask_zero_indices[torch.randperm(mask_zero_indices.numel())]
        self.grown_indices = shuffled[:cutoff]
        mask_flat[self.grown_indices] = 1
        return cutoff

    def get_sparsity(self):
        """Return ``(active_elements, sparsity)``.

        ``active_elements`` is the mask sum times the number of weights per
        mask entry; ``sparsity`` is ``1 - active_elements / s_tensor.numel()``.
        """
        per_entry = np.prod(
            self.s_tensor.size()[self.sub_kernel_granularity:]).item()
        active_elements = self.mask.sum() * per_entry
        return (active_elements, 1 - active_elements / self.s_tensor.numel())

    def forward(self):
        """Return ``s_tensor`` with masked-out entries multiplied by zero."""
        if not self.conv_tensor:
            return self.mask * self.s_tensor
        # Append singleton dims so the mask broadcasts over the trailing
        # dimensions it does not cover.
        padding = (1, ) * (4 - self.sub_kernel_granularity)
        return self.mask.view(*(self.mask.size() + padding)) * self.s_tensor

    def extra_repr(self):
        return 'full tensor size : {} , sparsity mask : {} , sub kernel granularity : {}'.format(
            self.s_tensor.size(), self.get_sparsity(),
            self.sub_kernel_granularity)