Example #1

import math

import torch
from torch.autograd import Variable
from torch.nn import Parameter

# NOTE: the base class `Kernel`, the `inner_product` kernel-function module,
# and the log-prior helper `smp` are project-specific imports assumed to be
# available from the surrounding codebase.


class InnerProductKernel(Kernel):

	def __init__(self, ndim, input_map=lambda x: x, diagonal=True):
		super(InnerProductKernel, self).__init__(input_map)
		self.diag = diagonal
		if diagonal:
			self.sigma_sqrt = Parameter(torch.FloatTensor(ndim))
		else:
			self.sigma_sqrt = Parameter(torch.FloatTensor(ndim, ndim))

	def reset_parameters(self):
		super(InnerProductKernel, self).reset_parameters()
		self.sigma_sqrt.data.normal_()

	def out_of_bounds(self, vec=None):
		# Only log_amp (the first vector entry) is bound-constrained; the
		# sigma_sqrt entries are unconstrained.
		if vec is None:
			return super(InnerProductKernel, self).out_of_bounds(self.log_amp)
		else:
			return super(InnerProductKernel, self).out_of_bounds(vec[:1])

	def n_params(self):
		return super(InnerProductKernel, self).n_params() + self.sigma_sqrt.numel()

	def param_to_vec(self):
		# Flatten sigma_sqrt so the same vector layout works for both the
		# diagonal (1D) and full (2D) parameterizations.
		return torch.cat([self.log_amp.data, self.sigma_sqrt.data.view(-1)])

	def vec_to_param(self, vec):
		self.log_amp.data = vec[0:1]
		self.sigma_sqrt.data = vec[1:]

	def prior(self, vec):
		return super(InnerProductKernel, self).prior(vec[:1]) + smp.normal(vec[1:])

	def forward(self, input1, input2=None):
		stabilizer = 0
		if input2 is None:
			# Gram matrix of a set with itself: add amplitude-scaled jitter
			# to the diagonal for numerical stability.
			input2 = input1
			stabilizer = Variable(torch.diag(input1.data.new(input1.size(0)).fill_(1e-6 * math.exp(self.log_amp.data[0]))))
		sigma_sqrt = self.sigma_sqrt if self.diag else self.sigma_sqrt.view(int(self.sigma_sqrt.numel() ** 0.5), -1)
		gram_mat = inner_product.InnerProductKernel.apply(self.input_map(input1), self.input_map(input2), self.log_amp, sigma_sqrt)
		return gram_mat + stabilizer

	def __repr__(self):
		return self.__class__.__name__ + ' (diag=' + str(self.diag) + ')'
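
For reference, here is a minimal usage sketch. It assumes the project-specific `Kernel` base class (which provides `log_amp`, `reset_parameters`, and the bounds/prior plumbing) behaves as the methods above expect; the shapes and sizes are illustrative, not from the original source.

# Hypothetical usage sketch (shapes are illustrative):
kernel = InnerProductKernel(ndim=5, diagonal=True)
kernel.reset_parameters()

vec = kernel.param_to_vec()              # [log_amp, sigma_sqrt entries]
assert vec.numel() == kernel.n_params()
kernel.vec_to_param(vec)                 # round-trip the flat hyperparameter vector

x = Variable(torch.randn(10, 5))
gram = kernel(x)                         # 10 x 10 Gram matrix with diagonal jitter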

Example #2

import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter


class SparseTensor(nn.Module):
    def __init__(self,
                 tensor_size,
                 initial_sparsity,
                 sub_kernel_granularity=4):
        super(SparseTensor, self).__init__()
        self.s_tensor = Parameter(torch.Tensor(torch.Size(tensor_size)))
        self.initial_sparsity = initial_sparsity
        self.sub_kernel_granularity = sub_kernel_granularity

        assert self.s_tensor.dim() in (2, 4), \
            "can only do 2D or 4D sparse tensors"

        # The binary mask covers only the first `sub_kernel_granularity`
        # dimensions; trailing kernel dimensions are kept or pruned as a unit.
        self.register_buffer(
            'mask', torch.Tensor(*(tensor_size[:sub_kernel_granularity])))

        self.normalize_coeff = np.prod(
            tensor_size[sub_kernel_granularity:]).item()

        self.conv_tensor = self.s_tensor.dim() == 4

        # Randomly activate a (1 - initial_sparsity) fraction of the mask.
        self.mask.zero_()
        flat_mask = self.mask.view(-1)
        indices = np.arange(flat_mask.size(0))
        np.random.shuffle(indices)
        flat_mask[indices[:int((1 - initial_sparsity) * flat_mask.size(0) +
                               0.1)]] = 1

        self.grown_indices = None
        self.init_parameters()
        self.reinitialize_unused()

        self.tensor_sign = torch.sign(self.s_tensor.data.view(-1))

    def reinitialize_unused(self, reinitialize_unused_to_zero=True):
        unused_positions = (self.mask < 0.5)
        if reinitialize_unused_to_zero:
            # Zero out weights that are masked off.
            self.s_tensor.data[unused_positions] = torch.zeros(
                self.s_tensor.data[unused_positions].size()).to(
                    self.s_tensor.device)
        elif self.conv_tensor:
            # He-style init: std = sqrt(2 / n) with
            # n = out_channels * kernel_h * kernel_w.
            n = self.s_tensor.size(0) * self.s_tensor.size(2) * \
                self.s_tensor.size(3)
            self.s_tensor.data[unused_positions] = torch.zeros(
                self.s_tensor.data[unused_positions].size()).normal_(
                    0, math.sqrt(2. / n)).to(self.s_tensor.device)
        else:
            # Linear weights: std = 1 / sqrt(fan_in).
            stdv = 1. / math.sqrt(self.s_tensor.size(1))
            self.s_tensor.data[unused_positions] = torch.zeros(
                self.s_tensor.data[unused_positions].size()).normal_(
                    0, stdv).to(self.s_tensor.device)

    def init_parameters(self):
        # Uniform init with stdv = 1 / sqrt(fan-in), matching PyTorch's
        # classic default for Linear/Conv2d layers.
        stdv = 1 / math.sqrt(np.prod(self.s_tensor.size()[1:]))

        self.s_tensor.data.uniform_(-stdv, stdv)

    def prune_sign_change(self,
                          reinitialize_unused_to_zero=True,
                          enable_print=False):
        W_flat = self.s_tensor.data.view(-1)

        new_tensor_sign = torch.sign(W_flat)
        mask_flat = self.mask.view(-1)

        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        # Active weights whose sign flipped since the last call: the product
        # of old and new signs is -1 for a flip.
        sign_flips = (new_tensor_sign[mask_indices] *
                      self.tensor_sign[mask_indices].to(
                          new_tensor_sign.device)) < -0.5
        sign_change_indices = mask_indices[sign_flips.nonzero().view(-1)]

        mask_flat[sign_change_indices] = 0
        self.reinitialize_unused(reinitialize_unused_to_zero)

        cutoff = sign_change_indices.numel()

        if enable_print:
            print('pruned {} connections'.format(cutoff))
        if self.grown_indices is not None and enable_print:
            overlap = np.intersect1d(sign_change_indices.cpu().numpy(),
                                     self.grown_indices.cpu().numpy())
            print('pruned {} ({} %) just grown weights'.format(
                overlap.size, overlap.size * 100.0 / self.grown_indices.size(0)
                if self.grown_indices.size(0) > 0 else 0.0))

        self.tensor_sign = new_tensor_sign
        return sign_change_indices

    def prune_small_connections(self,
                                prune_fraction,
                                reinitialize_unused_to_zero=True):
        # For conv tensors pruned at sub-kernel granularity, score each unit
        # by the mean absolute value of the weights it covers.
        if self.conv_tensor and self.sub_kernel_granularity < 4:
            W_flat = self.s_tensor.abs().sum(
                list(np.arange(self.sub_kernel_granularity,
                               4))).view(-1) / self.normalize_coeff
        else:
            W_flat = self.s_tensor.data.view(-1)

        mask_flat = self.mask.view(-1)

        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        W_masked = W_flat[mask_indices]

        sorted_W_indices = torch.sort(torch.abs(W_masked))[1]

        cutoff = int(prune_fraction * W_masked.numel()) + 1

        mask_flat[mask_indices[sorted_W_indices[:cutoff]]] = 0
        self.reinitialize_unused(reinitialize_unused_to_zero)

        return mask_indices[sorted_W_indices[:cutoff]]

    def prune_threshold(self, threshold, reinitialize_unused_to_zero=True):
        if self.conv_tensor and self.sub_kernel_granularity < 4:
            W_flat = self.s_tensor.abs().sum(
                list(np.arange(self.sub_kernel_granularity,
                               4))).view(-1) / self.normalize_coeff
        else:
            W_flat = self.s_tensor.data.view(-1)

        mask_flat = self.mask.view(-1)

        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        W_masked = W_flat[mask_indices]

        prune_indices = (W_masked.abs() < threshold).nonzero().view(-1)

        if mask_indices.size(0) == prune_indices.size(0):
            # Never prune everything; keep at least one active connection.
            print('removing all. keeping one')
            prune_indices = prune_indices[1:]

        mask_flat[mask_indices[prune_indices]] = 0

        self.reinitialize_unused(reinitialize_unused_to_zero)

        return mask_indices[prune_indices]

    def grow_random(self,
                    grow_fraction,
                    pruned_indices=None,
                    enable_print=False,
                    n_to_add=None):
        mask_flat = self.mask.view(-1)
        mask_zero_indices = torch.nonzero(mask_flat < 0.5).view(-1)
        if pruned_indices is not None:
            # Grow as many connections as were just pruned, but never in the
            # positions that were just freed.
            cutoff = pruned_indices.size(0)
            mask_zero_indices = torch.Tensor(
                np.setdiff1d(mask_zero_indices.cpu().numpy(),
                             pruned_indices.cpu().numpy())).long().to(
                                 mask_zero_indices.device)
        else:
            cutoff = int(grow_fraction * mask_zero_indices.size(0))

        if n_to_add is not None:
            cutoff = n_to_add

        if mask_zero_indices.numel() < cutoff:
            print('******no place to grow {} connections, growing {} instead'.
                  format(cutoff, mask_zero_indices.numel()))
            cutoff = mask_zero_indices.numel()

        if enable_print:
            print('grown {} connections'.format(cutoff))

        self.grown_indices = mask_zero_indices[torch.randperm(
            mask_zero_indices.numel())][:cutoff]
        mask_flat[self.grown_indices] = 1

        return cutoff

    def get_sparsity(self):
        # Count active weights (each mask entry controls a whole sub-kernel
        # block) and report (active count, sparsity fraction).
        active_elements = self.mask.sum() * np.prod(
            self.s_tensor.size()[self.sub_kernel_granularity:]).item()
        return (active_elements, 1 - active_elements / self.s_tensor.numel())

    def forward(self):
        if self.conv_tensor:
            # Broadcast the mask over the trailing kernel dimensions.
            return self.mask.view(
                *(self.mask.size() + (1, ) *
                  (4 - self.sub_kernel_granularity))) * self.s_tensor
        else:
            return self.mask * self.s_tensor

    def extra_repr(self):
        return 'full tensor size : {} , sparsity mask : {} , sub kernel granularity : {}'.format(
            self.s_tensor.size(), self.get_sparsity(),
            self.sub_kernel_granularity)
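
A minimal sketch of the intended prune-and-regrow cycle follows, assuming the tensor parameterizes a sparse linear weight matrix; the sizes, fractions, and variable names are illustrative, not from the original source.

# Hypothetical prune/grow cycle (sizes and fractions are illustrative):
w = SparseTensor([256, 128], initial_sparsity=0.9)

pruned = w.prune_small_connections(prune_fraction=0.2)   # drop smallest 20%
w.grow_random(0.0, pruned_indices=pruned,
              n_to_add=pruned.numel())                   # regrow elsewhere

active, sparsity = w.get_sparsity()
dense_weight = w()   # mask * weights, usable as a layer's weight matrix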