Code Example #1
class SubnetLinear(nn.Linear):
    # self.k is the % of weights remaining, a real number in [0,1]
    # self.popup_scores is a Parameter which has the same shape as self.weight
    # Gradients to self.weight, self.bias have been turned off.

    def __init__(self, in_features, out_features, bias=True):
        super(SubnetLinear, self).__init__(in_features,
                                           out_features,
                                           bias=bias)
        self.popup_scores = Parameter(torch.Tensor(self.weight.shape))
        nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5))
        self.weight.requires_grad = False
        if self.bias is not None:
            self.bias.requires_grad = False
        self.w = 0
        # self.register_buffer('w', None)

    def set_prune_rate(self, k):
        self.k = k

    def forward(self, x):
        # Get the subnetwork by sorting the scores.
        adj = GetSubnet.apply(self.popup_scores.abs(), self.k)

        # Use only the subnetwork in the forward pass.
        self.w = self.weight * adj
        x = F.linear(x, self.w, self.bias)

        return x
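The GetSubnet autograd function is referenced but not defined in this snippet. Below is a minimal sketch in the style of the hidden-networks/HYDRA implementations, assuming a straight-through estimator: the forward pass keeps the top-k fraction of scores as a binary mask, and the backward pass hands gradients to the scores unchanged.

import torch


class GetSubnet(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Binarize: keep the top-k fraction of scores, zero out the rest.
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())
        flat_out = out.flatten()  # view sharing storage with out
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1
        return out

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: gradient w.r.t. scores, none for k.
        return g, None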
Code Example #2
class ActQuant_init(nn.Module):
    def __init__(self, act_bit=4, scale_coef=10.0, extern_init=False, init_model=nn.Sequential()):
        super(ActQuant_init, self).__init__()
        self.pwr_coef = 2 ** act_bit
        self.act_bit = act_bit
        self.scale_coef = Parameter(torch.ones(1) * scale_coef)
        if extern_init:
            param = list(init_model.parameters())
            # Adopt the external scale only if it lies in a sane range.
            if 0.1 < param[0].item() < 10.0:
                self.scale_coef = Parameter(param[0])
            else:
                self.scale_coef = Parameter(torch.ones(1))

    def forward(self, x):
        if self.act_bit == 32:
            # Full precision: clamp(x, 0, scale_coef) normalized to [0, 1].
            out = 0.5 * (x.abs() - (x - self.scale_coef.abs()).abs() + self.scale_coef.abs()) / self.scale_coef.abs()
        else:
            # 0.5 * (|x| - |x - a| + a) == clamp(x, 0, a); then quantize.
            out = 0.5 * (x.abs() - (x - self.scale_coef.abs()).abs() + self.scale_coef.abs())
            out = RoundFn.apply(out / self.scale_coef.abs(), self.pwr_coef)
        return out * 2.0
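RoundFn is likewise not shown. A plausible minimal sketch, assuming it rounds the normalized activation onto a uniform grid with pwr_coef levels and uses a straight-through gradient (the exact grid is an assumption, not taken from the original project):

import torch


class RoundFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, pwr_coef):
        # Quantize x (expected in [0, 1]) to pwr_coef uniform levels.
        return torch.round(x * (pwr_coef - 1)) / (pwr_coef - 1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: identity gradient for x, none for pwr_coef.
        return grad_output, None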
Code Example #3
class FRN(nn.Module):
    def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
        """
        weight = gamma, bias = beta
        beta, gamma:
            Variables of shape [1, 1, 1, C]. if TensorFlow
            Variables of shape [1, C, 1, 1]. if PyTorch
        eps: A scalar constant or learnable variable.
        """
        super(FRN, self).__init__()

        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_leanable = is_eps_leanable

        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        if is_eps_leanable:
            self.eps = Parameter(torch.Tensor(1))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_leanable:
            nn.init.constant_(self.eps, self.init_eps)

    def extra_repr(self):
        return 'num_features={num_features}, eps={init_eps}'.format(
            **self.__dict__)

    def forward(self, x):
        """
        0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
        0, 1, 2, 3 -> (B, C, H, W) in PyTorch
        TensorFlow code
            nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
            x = x * tf.rsqrt(nu2 + tf.abs(eps))
            # This Code include TLU function max(y, tau)
            return tf.maximum(gamma * x + beta, tau)
        """
        # Compute the mean squared activation (nu2) per channel.
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)

        # Perform FRN.
        x = x * torch.rsqrt(nu2 + self.eps.abs())

        # Scale and Bias
        x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view(
            1, self.num_features, 1, 1)
        # x = self.weight * x + self.bias
        return x
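A quick usage sketch: FRN normalizes each channel by its mean squared activation, so it can stand in for BatchNorm2d without batch statistics. The sizes below are illustrative.

import torch

frn = FRN(num_features=64)
x = torch.randn(8, 64, 32, 32)  # (B, C, H, W)
y = frn(x)
print(y.shape)  # torch.Size([8, 64, 32, 32])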
Code Example #4
File: VBLayer.py  Project: manuelhaussmann/bedl
class VBLinear(nn.Module):
    def __init__(self, in_features, out_features, prior_prec=10, map=True):
        super(VBLinear, self).__init__()
        self.n_in = in_features
        self.n_out = out_features

        self.prior_prec = prior_prec
        self.map = map

        self.bias = nn.Parameter(th.Tensor(out_features))
        self.mu_w = Parameter(th.Tensor(out_features, in_features))
        self.logsig2_w = nn.Parameter(th.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        # TODO: Adapt to the newest pytorch initializations
        stdv = 1. / math.sqrt(self.mu_w.size(1))
        self.mu_w.data.normal_(0, stdv)
        self.logsig2_w.data.zero_().normal_(-9, 0.001)  # var init via Louizos
        self.bias.data.zero_()

    def KL(self, loguniform=False):
        if loguniform:
            k1 = 0.63576
            k2 = 1.87320
            k3 = 1.48695
            log_alpha = self.logsig2_w - 2 * th.log(self.mu_w.abs() + 1e-8)
            kl = -th.sum(k1 * th.sigmoid(k2 + k3 * log_alpha) -
                         0.5 * F.softplus(-log_alpha) - k1)
        else:
            logsig2_w = self.logsig2_w.clamp(-11, 11)
            kl = 0.5 * (self.prior_prec *
                        (self.mu_w.pow(2) + logsig2_w.exp()) - logsig2_w - 1 -
                        np.log(self.prior_prec)).sum()
        return kl

    def forward(self, input):
        # Sampling-free forward pass for MAP prediction at test time
        if self.map and not self.training:
            return F.linear(input, self.mu_w, self.bias)
        else:
            mu_out = F.linear(input, self.mu_w, self.bias)
            logsig2_w = self.logsig2_w.clamp(-11, 11)
            s2_w = logsig2_w.exp()
            var_out = F.linear(input.pow(2), s2_w) + 1e-8
            return mu_out + var_out.sqrt() * th.randn_like(mu_out)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.n_in) + ' -> ' \
               + str(self.n_out) + ')'
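A sketch of how such a layer is typically trained: the KL term is added to the data loss to form the variational objective. The 1/N weighting below is illustrative.

import torch as th
import torch.nn.functional as F

layer = VBLinear(784, 10)
x, target = th.randn(32, 784), th.randint(0, 10, (32,))

layer.train()
logits = layer(x)  # training mode samples via the reparameterized variance
loss = F.cross_entropy(logits, target) + layer.KL() / 60000  # e.g. 1/N weighting
loss.backward()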
Code Example #5
class SubnetConv(nn.Conv2d):
    # self.k is the % of weights remaining, a real number in [0,1]
    # self.popup_scores is a Parameter which has the same shape as self.weight
    # Gradients to self.weight, self.bias have been turned off by default.

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(SubnetConv, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.popup_scores = Parameter(torch.Tensor(self.weight.shape))
        nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5))

        self.weight.requires_grad = False
        if self.bias is not None:
            self.bias.requires_grad = False
        self.w = 0

    def set_prune_rate(self, k):
        self.k = k

    def forward(self, x):
        # Get the subnetwork by sorting the scores.
        adj = GetSubnet.apply(self.popup_scores.abs(), self.k)

        # Use only the subnetwork in the forward pass.
        self.w = self.weight * adj
        x = F.conv2d(x, self.w, self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
        return x
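Since k is only set after construction via set_prune_rate, a small helper (illustrative, not taken from the original repo) can assign the prune rate to every subnet layer in a model:

def set_prune_rate_model(model, k):
    # Walk the module tree and set k on every layer that supports it.
    for module in model.modules():
        if hasattr(module, "set_prune_rate"):
            module.set_prune_rate(k)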
Code Example #6
class RTML(nn.Module):
    def __init__(self, L=3, lamb=5):
        super(RTML, self).__init__()
        self.L = L
        self.N = len(att_names)
        self.lamb = lamb
        self.theta = Parameter(torch.Tensor(self.L, 300, 300))
        # The extra (L+1)-th column acts as slack, so after rescaling the first
        # L coefficients can have L1 norm strictly smaller than lamb.
        self.alpha = Parameter(torch.Tensor(self.N, self.L + 1))
        self.reset_parameters()

        self.att_emb = nn.Embedding(self.N, 300)
        if PREINIT:
            self.att_emb.weight.data = _load_vectors(att_names).cuda()
        else:
            _np_emb = np.random.randn(self.N, 300)
            _np_emb = _np_emb / np.square(_np_emb).sum(1)[:, None]
            self.att_emb.weight.data = torch.FloatTensor(_np_emb).cuda()

    def reset_parameters(self):
        for weight in [self.theta, self.alpha]:
            stdv = 1. / math.sqrt(weight.size(1))
            weight.data.uniform_(-stdv, stdv)

    def forward(self, word_embs):
        alpha_norm = self.alpha.abs().sum(1, keepdim=True)
        alpha_constrained = self.lamb * self.alpha / alpha_norm

        R_flat = alpha_constrained[:, :-1] @ self.theta.view(self.L, -1)
        R = R_flat.view(self.N, 300, 300)

        s = 0
        preds = []
        for i, att_size in enumerate(dom_sizes):
            e = s + att_size
            att_embs = self.att_emb.weight[s:e].t()
            s = e

            p1 = word_embs @ R[i]
            p2 = p1 @ att_embs
            preds.append(p2)
        preds = torch.cat(preds, 1)
        return preds
Code Example #7
class Conv2dDPQ(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 qmin=1e-3,
                 qmax=100,
                 dmin=1e-5,
                 dmax=10,
                 bias=True,
                 sign=True,
                 wbits=4,
                 abits=4,
                 mode=Qmodes.layer_wise):
        """
        :param d_init: the inital quantization stepsize (alpha)
        :param mode: Qmodes.layer_wise or Qmodes.kernel_wise
        :param xmax_init: the quantization range for whole weights 
        """

        super(Conv2dDPQ, self).__init__(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding,
                                        dilation=dilation,
                                        groups=groups,
                                        bias=bias)

        self.qmin = qmin
        self.qmax = qmax
        self.dmin = dmin
        self.dmax = dmax
        self.q_mode = mode
        self.sign = sign
        self.nbits = wbits
        self.act_dpq = ActDPQ(signed=False, nbits=abits)
        self.alpha = Parameter(torch.Tensor(1))
        self.xmax = Parameter(torch.Tensor(1))
        self.weight.requires_grad_(True)
        if bias:
            self.bias.requires_grad_(True)
        self.register_buffer('init_state', torch.zeros(1))

    def get_nbits(self):
        abits = self.act_dpq.get_nbits()
        xmax = self.xmax.abs().item()
        alpha = self.alpha.abs().item()
        if self.sign:
            nbits = math.ceil(math.log(xmax / alpha + 1) / math.log(2) + 1)
        else:
            nbits = math.ceil(math.log(xmax / alpha + 1) / math.log(2))

        self.nbits = nbits
        return abits, nbits

    def get_quan_filters(self, filters):
        if self.training and self.init_state == 0:
            Qp = 2**(self.nbits - 1) - 1
            self.xmax.data.copy_(filters.abs().max())
            self.alpha.data.copy_(self.xmax / Qp)
            # self.alpha[self.index].data.copy_(2 * filters.abs().mean() / math.sqrt(Qp))
            # self.xmax[self.index].data.copy_(self.alpha[self.index] * Qp)
            self.init_state.fill_(1)

        Qp = (self.xmax.detach() / self.alpha.detach()).abs().item()
        g = 1.0 / math.sqrt(filters.numel() * Qp)
        alpha = grad_scale(self.alpha, g)
        xmax = grad_scale(self.xmax, g)

        w = F.hardtanh(filters / xmax.abs(), -1, 1) * xmax.abs()
        w = w / alpha.abs()
        wq = round_pass(w) * alpha.abs()

        return wq

    def forward(self, x):
        if self.act_dpq is not None:
            x = self.act_dpq(x)

        wq = self.get_quan_filters(self.weight)
        return F.conv2d(x, wq, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
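grad_scale and round_pass are the usual LSQ-style helpers and are not shown in this snippet; minimal sketches of their standard definitions:

def grad_scale(x, scale):
    # Forward: identity. Backward: gradient is multiplied by `scale`.
    y = x * scale
    return (x - y).detach() + y


def round_pass(x):
    # Forward: round to nearest. Backward: straight-through identity gradient.
    return (x.round() - x).detach() + x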
Code Example #8
class SparseTensor(nn.Module):
    def __init__(self,
                 tensor_size,
                 initial_sparsity,
                 sub_kernel_granularity=4):
        super(SparseTensor, self).__init__()
        self.s_tensor = Parameter(torch.Tensor(torch.Size(tensor_size)))
        self.initial_sparsity = initial_sparsity
        self.sub_kernel_granularity = sub_kernel_granularity

        assert self.s_tensor.dim() == 2 or self.s_tensor.dim(
        ) == 4, "can only do 2D or 4D sparse tensors"

        trailing_dimensions = [1] * (4 - sub_kernel_granularity)
        self.register_buffer(
            'mask', torch.Tensor(*(tensor_size[:sub_kernel_granularity])))

        self.normalize_coeff = np.prod(
            tensor_size[sub_kernel_granularity:]).item()

        self.conv_tensor = self.s_tensor.dim() == 4

        self.mask.zero_()
        flat_mask = self.mask.view(-1)
        indices = np.arange(flat_mask.size(0))
        np.random.shuffle(indices)
        flat_mask[indices[:int((1 - initial_sparsity) * flat_mask.size(0) +
                               0.1)]] = 1

        self.grown_indices = None
        self.init_parameters()
        self.reinitialize_unused()

        self.tensor_sign = torch.sign(self.s_tensor.data.view(-1))

    def reinitialize_unused(self, reinitialize_unused_to_zero=True):
        unused_positions = (self.mask < 0.5)
        if reinitialize_unused_to_zero:
            self.s_tensor.data[unused_positions] = torch.zeros(
                self.s_tensor.data[unused_positions].size()).to(
                    self.s_tensor.device)
        else:
            if self.conv_tensor:
                n = self.s_tensor.size(0) * self.s_tensor.size(
                    2) * self.s_tensor.size(3)
                self.s_tensor.data[unused_positions] = torch.zeros(
                    self.s_tensor.data[unused_positions].size()).normal_(
                        0, math.sqrt(2. / n)).to(self.s_tensor.device)
            else:
                stdv = 1. / math.sqrt(self.s_tensor.size(1))
                self.s_tensor.data[unused_positions] = torch.zeros(
                    self.s_tensor.data[unused_positions].size()).normal_(
                        0, stdv).to(self.s_tensor.device)

    def init_parameters(self):
        stdv = 1 / math.sqrt(np.prod(self.s_tensor.size()[1:]))

        self.s_tensor.data.uniform_(-stdv, stdv)

    def prune_sign_change(self,
                          reinitialize_unused_to_zero=True,
                          enable_print=False):
        W_flat = self.s_tensor.data.view(-1)

        new_tensor_sign = torch.sign(W_flat)
        mask_flat = self.mask.view(-1)

        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        sign_change_indices = mask_indices[(
            (new_tensor_sign[mask_indices] *
             self.tensor_sign[mask_indices].to(new_tensor_sign.device)) <
            -0.5).nonzero().view(-1)]

        mask_flat[sign_change_indices] = 0
        self.reinitialize_unused(reinitialize_unused_to_zero)

        cutoff = sign_change_indices.numel()

        if enable_print:
            print('pruned {}  connections'.format(cutoff))
        if self.grown_indices is not None and enable_print:
            overlap = np.intersect1d(sign_change_indices.cpu().numpy(),
                                     self.grown_indices.cpu().numpy())
            print('pruned {} ({} %) just grown weights'.format(
                overlap.size, overlap.size * 100.0 / self.grown_indices.size(0)
                if self.grown_indices.size(0) > 0 else 0.0))

        self.tensor_sign = new_tensor_sign
        return sign_change_indices

    def prune_small_connections(self,
                                prune_fraction,
                                reinitialize_unused_to_zero=True):
        if self.conv_tensor and self.sub_kernel_granularity < 4:
            W_flat = self.s_tensor.abs().sum(
                list(np.arange(self.sub_kernel_granularity,
                               4))).view(-1) / self.normalize_coeff
        else:
            W_flat = self.s_tensor.data.view(-1)

        mask_flat = self.mask.view(-1)

        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        W_masked = W_flat[mask_indices]

        sorted_W_indices = torch.sort(torch.abs(W_masked))[1]

        cutoff = int(prune_fraction * W_masked.numel()) + 1

        mask_flat[mask_indices[sorted_W_indices[:cutoff]]] = 0
        self.reinitialize_unused(reinitialize_unused_to_zero)

        #        print('pruned {}  connections'.format(cutoff))
        #        if self.grown_indices is not None:
        #            overlap = np.intersect1d(mask_indices[sorted_W_indices[:cutoff]].cpu().numpy(),self.grown_indices.cpu().numpy())
        #print('pruned {} ({} %) just grown weights'.format(overlap.size,overlap.size * 100.0 / self.grown_indices.size(0)))

        return mask_indices[sorted_W_indices[:cutoff]]

    def prune_threshold(self, threshold, reinitialize_unused_to_zero=True):
        if self.conv_tensor and self.sub_kernel_granularity < 4:
            W_flat = self.s_tensor.abs().sum(
                list(np.arange(self.sub_kernel_granularity,
                               4))).view(-1) / self.normalize_coeff
        else:
            W_flat = self.s_tensor.data.view(-1)

        mask_flat = self.mask.view(-1)

        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        W_masked = W_flat[mask_indices]

        prune_indices = (W_masked.abs() < threshold).nonzero().view(-1)

        if mask_indices.size(0) == prune_indices.size(0):
            print('removing all. keeping one')
            prune_indices = prune_indices[1:]

        mask_flat[mask_indices[prune_indices]] = 0

        #       if mask_indices.numel() > 0 :
        #           print('pruned {}/{}({:.2f})  connections'.format(prune_indices.numel(),mask_indices.numel(),prune_indices.numel()/mask_indices.numel()))

        #        if self.grown_indices is not None and self.grown_indices.size(0) != 0 :
        #            overlap = np.intersect1d(mask_indices[prune_indices].cpu().numpy(),self.grown_indices.cpu().numpy())
        #            print('pruned {} ({} %) just grown weights'.format(overlap.size,overlap.size * 100.0 / self.grown_indices.size(0)))

        self.reinitialize_unused(reinitialize_unused_to_zero)

        return mask_indices[prune_indices]

    def grow_random(self,
                    grow_fraction,
                    pruned_indices=None,
                    enable_print=False,
                    n_to_add=None):
        mask_flat = self.mask.view(-1)
        mask_zero_indices = torch.nonzero(mask_flat < 0.5).view(-1)
        if pruned_indices is not None:
            cutoff = pruned_indices.size(0)
            mask_zero_indices = torch.Tensor(
                np.setdiff1d(mask_zero_indices.cpu().numpy(),
                             pruned_indices.cpu().numpy())).long().to(
                                 mask_zero_indices.device)
        else:
            cutoff = int(grow_fraction * mask_zero_indices.size(0))

        if n_to_add is not None:
            cutoff = n_to_add

        if mask_zero_indices.numel() < cutoff:
            print('******no place to grow {} connections, growing {} instead'.
                  format(cutoff, mask_zero_indices.numel()))
            cutoff = mask_zero_indices.numel()

        if enable_print:
            print('grown {}  connections'.format(cutoff))

        self.grown_indices = mask_zero_indices[torch.randperm(
            mask_zero_indices.numel())][:cutoff]
        mask_flat[self.grown_indices] = 1

        return cutoff

    def get_sparsity(self):
        active_elements = self.mask.sum() * np.prod(
            self.s_tensor.size()[self.sub_kernel_granularity:]).item()
        return (active_elements, 1 - active_elements / self.s_tensor.numel())

    def forward(self):
        if self.conv_tensor:
            return self.mask.view(
                *(self.mask.size() + (1, ) *
                  (4 - self.sub_kernel_granularity))) * self.s_tensor
        else:
            return self.mask * self.s_tensor

    def extra_repr(self):
        return 'full tensor size : {} , sparsity mask : {} , sub kernel granularity : {}'.format(
            self.s_tensor.size(), self.get_sparsity(),
            self.sub_kernel_granularity)
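A sketch of one prune-regrow step with this class (sizes and fractions are illustrative):

import torch

st = SparseTensor([64, 128], initial_sparsity=0.9)
dense = st()  # forward() returns the masked tensor
print(st.get_sparsity())  # (active elements, sparsity fraction)

pruned = st.prune_small_connections(0.1)  # drop the smallest 10% of live weights
st.grow_random(0.0, pruned_indices=pruned)  # regrow the same number elsewhere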
Code Example #9
class group_relaxed_L0Dense(Module):
    """Implementation of TFL regularization for the input units of a fully connected layer"""
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 lamba=1.,
                 beta=4.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the TF1 regularization
		"""
        super(group_relaxed_L0Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.u = torch.rand(in_features, out_features)
        self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.beta = beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        #self.weight.data = F.normalize(self.weight.data, p=2, dim=1)
        m = Hardshrink((2 * self.lamba / self.beta)**(1 / 2))
        self.u.data = m(self.weight.data)

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(
            0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(
                self.out_features) * torch.sum(
                    torch.pow(torch.sum(self.weight.pow(2), 1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(
            torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__+' (' \
         + str(self.in_features) + ' -> ' \
         + str(self.out_features) + ', lambda: ' \
         + str(self.lamba) + ')'
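A sketch of how the regularizer enters the training objective (the 1/N weighting is illustrative; CUDA is required because the class pins u to the GPU):

import torch
import torch.nn.functional as F

layer = group_relaxed_L0Dense(100, 10).cuda()
x, target = torch.randn(32, 100).cuda(), torch.randint(0, 10, (32,)).cuda()

out = layer(x)
# regularization() is a log-prior, so it is subtracted from the loss.
loss = F.cross_entropy(out, target) - layer.regularization() / 50000
loss.backward()
layer.constrain_parameters()  # refresh the auxiliary u via hard-shrinkage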
Code Example #10
class MAPConv2d(Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 weight_decay=1.,
                 **kwargs):
        super(MAPConv2d, self).__init__()
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, thres_std=1.):
        pass

    def _reg_w(self, **kwargs):
        logpw = -torch.sum(self.weight_decay * .5 * (self.weight.pow(2)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_zero_u(self):
        return 0

    def count_weight(self):
        return np.prod(self.weight.size())

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) /
                          (self.in_channels * self.kernel_size[0] *
                           self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_expected_flops_and_l0(self):
        ppos = self.out_channels
        n = self.kernel_size[0] * self.kernel_size[
            1] * self.in_channels  # vector_length
        flops_per_instance = n + (n - 1
                                  )  # (n: multiplications and n-1: additions)

        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) /
            self.stride[0]) + 1  # for rows
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) /
            self.stride[1]) + 1  # multiplying with cols

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos  # multiply with number of filters
        expected_l0 = n * ppos

        if self.bias is not None:
            # since the gate is applied to the output we also reduce the bias computation
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos

        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} '
             ', stride={stride}, weight_decay={weight_decay}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
Code Example #11
class MAPDense(Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 weight_decay=1.,
                 **kwargs):
        super(MAPDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.weight_decay = weight_decay
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        pass

    def _reg_w(self, **kwargs):
        logpw = -torch.sum(self.weight_decay * .5 * (self.weight.pow(2)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_zero_u(self):
        return 0

    def count_weight(self):
        return np.prod(self.weight.size())

    def count_active_neuron(self):
        return torch.sum(
            torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        # dim_in multiplications and dim_in - 1 additions for each output neuron for the weights
        # + the bias addition for each neuron
        # total_flops = (2 * in_features - 1) * out_features + out_features
        expected_flops = (2 * self.in_features - 1) * self.out_features
        expected_l0 = self.in_features * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', weight_decay: ' \
            + str(self.weight_decay) + ')'
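As a concrete check of count_expected_flops_and_l0 with illustrative sizes: a 100 -> 10 layer with bias needs (2 * 100 - 1) * 10 + 10 = 2000 flops and has 100 * 10 + 10 = 1010 parameters.

layer = MAPDense(100, 10)
flops, l0 = layer.count_expected_flops_and_l0()
assert flops == 2000 and l0 == 1010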
Code Example #12
File: layers.py  Project: timko98/hydra
class SubnetConv(nn.Conv2d):
    # self.k is the % of weights remaining, a real number in [0,1]
    # self.popup_scores is a Parameter which has the same shape as self.weight
    # Gradients to self.weight, self.bias have been turned off by default.

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(SubnetConv, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        # Weight pruning
        # self.popup_scores = Parameter(torch.Tensor(self.weight.shape))
        # Channel Finetuning or Resume Pruning
        # self.popup_scores = Parameter(torch.Tensor(torch.Size([1,self.weight.shape[1],1,1])))
        # Channel Pruning
        self.popup_scores = Parameter(
            torch.Tensor(torch.Size([self.weight.shape[0], 1, 1, 1])))

        nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5))

        self.weight.requires_grad = False
        if self.bias is not None:
            self.bias.requires_grad = False
        self.w = 0

    def set_prune_rate(self, k):
        self.k = k

    def forward(self, x):
        """ Unstructured comparison
        remaining_weights = int(self.k * len(self.weight.flatten()))
        idx_same_top_weights_scores = list(
            set(torch.topk(self.weight.abs().flatten(), remaining_weights).indices.tolist()).intersection(
                set(torch.topk(self.popup_scores.abs().flatten(), remaining_weights).indices.tolist())))
        num_remaining_weights = len(idx_same_top_weights_scores)
        print(
            f"SubnetConv: Number of same indices for scores and weights that are left after pruning: "
            f"{num_remaining_weights}. These are {float(num_remaining_weights / remaining_weights)} percent of the "
            f"weights kept.")
        """
        """ Structured Comparison
        remaining_filters = int(self.k * self.weight.shape[0])
        idx_same_top_weights_scores = list(set(
            torch.topk(torch.linalg.norm(self.weight.abs().reshape(self.weight.shape[0], -1), 1, dim=1),
                       remaining_filters).indices.tolist()).intersection(
            torch.topk(torch.linalg.norm(self.popup_scores.abs().reshape(self.popup_scores.shape[0], -1), 1, dim=1),
                       remaining_filters).indices.tolist()))
        num_remaining_filters = len(idx_same_top_weights_scores)
        print(
            f"SubnetConv: Number of same indices for filters that are left after pruning using scores or weights : "
            f"{num_remaining_filters}. These are {float(num_remaining_filters / remaining_filters)} percent of the "
            f"filters kept.")
        """
        """ Channel Prune VGG16
        global conv_nr
        if conv_nr == 13:
            conv_nr = 1
        else:
            conv_nr += 1
        # Get the subnetwork by sorting the scores.
        mask_conv_50 = [1.0, 1.0, 0.984375, 1.0, 1.0, 0.98828125, 0.98046875, 1.0, 0.96875, 0.359375, 0.099609375, 0.1015625, 0.099609375]
        mask_conv_10 = [1.0, 0.5, 0.46875, 0.4921875, 0.484375, 0.4765625, 0.5, 0.5, 0.48242188, 0.05078125, 0.0234375, 0.015625, 0.015625]
        k = mask_conv_10[conv_nr-1]
        if conv_nr == 1:
            adj = GetSubnet.apply(self.popup_scores.abs(), 1)
        else:
            adj = GetSubnet.apply(self.popup_scores.abs(), self.k)
        """
        # conv_nr is a module-level counter that cycles over the network's conv layers.
        global conv_nr
        if conv_nr == 28:
            conv_nr = 1
        else:
            conv_nr += 1
        """
        mask_wrn_50 = [1, 0.5, 0.171875, 0.5, 0.5625, 0.5, 0.359375, 0.40625, 0.375, 0.1875, 0.390625, 0.4453125, 0.390625, 0.328125, 0.171875, 0.4765625, 0.3046875, 0.140625, 0.2265625,0.640625, 0.49609375, 0.640625, 0.6015625, 0.6796875, 0.46875, 0.52734375, 0.50390625, 0.48825125]
        # Add mask for 0.1 Channel pruning here
        mask_wrn_10 = None
        k = mask_wrn_50[conv_nr-1]
        adj = GetSubnet.apply(self.popup_scores.abs(), k)
        """
        if conv_nr == 1:
            adj = GetSubnet.apply(self.popup_scores.abs(), 1)
        else:
            adj = GetSubnet.apply(self.popup_scores.abs(), self.k)

        # Use only the subnetwork in the forward pass.
        self.w = self.weight * adj
        x = F.conv2d(x, self.w, self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
        return x
Code Example #13
File: linear.py  Project: jiangyuang/PruneFL
class DenseLinear(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self,
                 in_features,
                 out_features,
                 use_bias=True,
                 use_mask=True,
                 **kwargs):
        super(DenseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if use_bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters(**kwargs)

        # self._initial_weight = self.weight.data.clone()
        # self._initial_bias = self.bias.data.clone() if use_bias else None
        self.use_mask = use_mask
        self.mask = torch.ones_like(self.weight, dtype=torch.bool)

    def reset_parameters(self, **kwargs):
        if len(kwargs.keys()) == 0:
            # default init, see https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        else:
            init.kaiming_uniform_(self.weight, **kwargs)

        if self.bias is not None:
            # default init, see https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, inp: torch.Tensor):
        masked_weight = self.weight * self.mask if self.use_mask else self.weight
        return nn.functional.linear(inp, masked_weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

    def prune_by_threshold(self, thr):
        self.mask *= (self.weight.abs() >= thr)

    def prune_by_rank(self, rank):
        if rank == 0:
            return
        weight_val = self.weight[self.mask == 1.]
        sorted_abs_weight = weight_val.abs().sort()[0]
        thr = sorted_abs_weight[rank]
        self.prune_by_threshold(thr)

    def prune_by_pct(self, pct):
        prune_idx = int(self.num_weight * pct)
        self.prune_by_rank(prune_idx)

    def retain_by_threshold(self, thr):
        self.mask *= (self.weight.abs() >= thr)

    def retain_by_rank(self, rank):
        weights_val = self.weight[self.mask == 1.]
        sorted_abs_weights = weights_val.abs().sort(descending=True)[0]
        thr = sorted_abs_weights[rank]
        self.retain_by_threshold(thr)

    def random_prune_by_pct(self, pct):
        prune_idx = int(self.num_weight * pct)
        rand = torch.rand(size=self.mask.size(), device=self.mask.device)
        rand_val = rand[self.mask == 1]
        sorted_abs_rand = rand_val.sort()[0]
        thr = sorted_abs_rand[prune_idx]
        self.mask *= (rand >= thr)

    # def reinitialize(self):
    #     self.weight = Parameter(self._initial_weight)
    #     if self._initial_bias is not None:
    #         self.bias = Parameter(self._initial_bias)

    def to_sparse(self, transpose=False) -> SparseLinear:
        """
        by chance, some entries with mask = 1 can have a 0 value. Thus, the to_sparse methods give a different size
        there's no efficient way to solve it yet
        """
        sparse_bias = None if self.bias is None else self.bias.reshape((-1, 1))
        sparse_linear = SparseLinear((self.weight * self.mask).to_sparse(),
                                     sparse_bias, self.mask)
        if transpose:
            sparse_linear.transpose = True
        return sparse_linear

    def move_data(self, device: torch.device):
        self.mask = self.mask.to(device)

    def to(self, *args, **kwargs):
        device = torch._C._nn._parse_to(*args, **kwargs)[0]

        if device is not None:
            self.move_data(device)

        return super(DenseLinear, self).to(*args, **kwargs)

    @property
    def num_weight(self) -> int:
        return self.mask.sum().item()
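A sketch of magnitude pruning with this layer (sizes illustrative):

import torch

fc = DenseLinear(256, 128)
print(fc.num_weight)  # 32768 weights active initially

fc.prune_by_pct(0.5)  # mask out the 50% smallest-magnitude weights
print(fc.num_weight)  # roughly 16384 remain (ties may shift the count)

y = fc(torch.randn(4, 256))  # forward uses the masked weight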
Code Example #14
class FilterStripe(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels,
                                           out_channels,
                                           kernel_size,
                                           stride,
                                           kernel_size // 2,
                                           groups=1,
                                           bias=False)
        self.BrokenTarget = None
        self.FilterSkeleton = Parameter(torch.ones(self.out_channels,
                                                   self.kernel_size[0],
                                                   self.kernel_size[1]),
                                        requires_grad=True)

    def forward(self, x):
        if self.BrokenTarget is not None:
            out = torch.zeros(x.shape[0], self.FilterSkeleton.shape[0],
                              int(np.ceil(x.shape[2] / self.stride[0])),
                              int(np.ceil(x.shape[3] / self.stride[1])))
            if x.is_cuda:
                out = out.cuda()
            x = F.conv2d(x, self.weight)
            l, h = 0, 0
            for i in range(self.BrokenTarget.shape[0]):
                for j in range(self.BrokenTarget.shape[1]):
                    h += self.FilterSkeleton[:, i, j].sum().item()
                    out[:, self.FilterSkeleton[:, i, j]] += self.shift(
                        x[:, l:h], i,
                        j)[:, :, ::self.stride[0], ::self.stride[1]]
                    l += self.FilterSkeleton[:, i, j].sum().item()
            return out
        else:
            return F.conv2d(x,
                            self.weight * self.FilterSkeleton.unsqueeze(1),
                            stride=self.stride,
                            padding=self.padding,
                            groups=self.groups)

    def prune_in(self, in_mask=None):
        self.weight = Parameter(self.weight[:, in_mask])
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0
        if out_mask.sum() == 0:
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask],
                                        requires_grad=True)
        self.out_channels = out_mask.sum().item()
        return out_mask

    def _break(self, threshold):
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))
        self.FilterSkeleton = Parameter(
            (self.FilterSkeleton.abs() > threshold), requires_grad=False)
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        self.weight = Parameter(
            self.weight.permute(2, 3, 0,
                                1).reshape(-1, self.in_channels, 1,
                                           1)[self.FilterSkeleton.permute(
                                               1, 2, 0).reshape(-1)])

    def update_skeleton(self, sr, threshold):
        self.FilterSkeleton.grad.data.add_(
            sr * torch.sign(self.FilterSkeleton.data))
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)
        self.FilterSkeleton.grad.data.mul_(mask)
        out_mask = mask.sum(dim=(1, 2)) != 0
        return out_mask

    def shift(self, x, i, j):
        return F.pad(x, (self.BrokenTarget.shape[0] // 2 - j,
                         j - self.BrokenTarget.shape[0] // 2,
                         self.BrokenTarget.shape[0] // 2 - i,
                         i - self.BrokenTarget.shape[1] // 2), 'constant', 0)

    def extra_repr(self):
        s = (
            '{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
            ', stride={stride}')
        return s.format(**self.__dict__)
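A sketch of how update_skeleton fits into a training step (sr and threshold values are illustrative):

import torch

conv = FilterStripe(16, 32)
x = torch.randn(2, 16, 8, 8)

loss = conv(x).pow(2).mean()
loss.backward()
# Add the L1 sparsity gradient on the FilterSkeleton and zero out tiny stripes.
conv.update_skeleton(sr=1e-5, threshold=0.01)
# optimizer.step() would follow here.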
Code Example #15
File: CLTLayer.py  Project: manuelhaussmann/bedl
class CLTLinear(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 prior_prec=10,
                 relu_act=True,
                 elu_act=False):
        super(CLTLinear, self).__init__()
        self.n_in = in_features
        self.n_out = out_features

        self.prior_prec = prior_prec

        assert not (
            relu_act and elu_act
        )  # A single layer can only do either relu or elu activation
        self.relu_act = relu_act
        self.elu_act = elu_act

        self.bias = nn.Parameter(th.Tensor(out_features))
        self.mu_w = Parameter(th.Tensor(out_features, in_features))
        self.logsig2_w = nn.Parameter(th.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        # TODO: Adapt to the newest pytorch initializations
        stdv = 1. / math.sqrt(self.mu_w.size(1))
        self.mu_w.data.normal_(0, stdv)
        self.logsig2_w.data.zero_().normal_(-9, 0.001)
        self.bias.data.zero_()

    def KL(self, loguniform=False):
        if loguniform:
            k1 = 0.63576
            k2 = 1.87320
            k3 = 1.48695
            log_alpha = self.logsig2_w - 2 * th.log(self.mu_w.abs() + 1e-8)
            kl = -th.sum(k1 * th.sigmoid(k2 + k3 * log_alpha) -
                         0.5 * F.softplus(-log_alpha) - k1)
        else:
            logsig2_w = self.logsig2_w.clamp(-11, 11)
            kl = 0.5 * (self.prior_prec *
                        (self.mu_w.pow(2) + logsig2_w.exp()) - logsig2_w - 1 -
                        np.log(self.prior_prec)).sum()
        return kl

    def cdf(self, x, mu=0., sig=1.):
        return 0.5 * (1 + th.erf((x - mu) / (sig * math.sqrt(2))))

    def pdf(self, x, mu=0., sig=1.):
        return (1 / (math.sqrt(2 * math.pi) * sig)) * th.exp(-0.5 * (
            (x - mu) / sig).pow(2))

    def relu_moments(self, mu, sig):
        alpha = mu / sig
        cdf = self.cdf(alpha)
        pdf = self.pdf(alpha)
        relu_mean = mu * cdf + sig * pdf
        relu_var = (sig.pow(2) +
                    mu.pow(2)) * cdf + mu * sig * pdf - relu_mean.pow(2)
        relu_var.clamp_(1e-8)  # Avoid negative variance due to numerics
        return relu_mean, relu_var

    def elu_moments_orig(self, mu, sig):
        # the original method without simplifications
        sig2 = sig.pow(2)
        elu_mean = th.exp(mu.clamp_max(10) + sig2 / 2) * self.cdf(
            -(mu + sig2) / sig) - self.cdf(-mu / sig)
        elu_mean += mu * self.cdf(mu / sig) + sig * self.pdf(mu / sig)
        elu_var = th.exp(2 * mu.clamp_max(10) + 2 * sig2) * self.cdf(
            -(mu + 2 * sig2) / sig)
        elu_var += -2 * th.exp(mu.clamp_max(10) + sig2 / 2) * self.cdf(
            -(mu + sig2) / sig)
        elu_var += self.cdf(-mu / sig)
        elu_var += (sig2 + mu.pow(2)) * self.cdf(
            mu / sig) + mu * sig * self.pdf(mu / sig)
        elu_var += -elu_mean.pow(2)
        elu_var.clamp_min_(1e-8)  # Avoid negative variance due to numerics
        return elu_mean, elu_var

    def elu_moments(self, mu, sig):
        # NOTE: For now it is without alpha or the selu extension!
        # Note: Takes roughly 3x as much time as the relu
        # Clamp the mus to avoid problems in the expectation
        sig2 = sig.pow(2)
        alpha = mu / sig

        cdf_alpha = self.cdf(alpha)
        pdf_alpha = self.pdf(alpha)
        cdf_malpha = 1 - cdf_alpha
        cdf_malphamsig = self.cdf(-alpha - sig)

        elu_mean = th.exp(mu.clamp_max(10) +
                          sig2 / 2) * cdf_malphamsig - cdf_malpha
        elu_mean += mu * cdf_alpha + sig * pdf_alpha

        elu_var = th.exp(2 * mu.clamp_max(10) + 2 * sig2) * self.cdf(-alpha -
                                                                     2 * sig)
        elu_var += -2 * th.exp(mu.clamp_max(10) + sig2 / 2) * cdf_malphamsig
        elu_var += cdf_malpha
        elu_var += (sig2 + mu.pow(2)) * cdf_alpha + mu * sig * pdf_alpha
        elu_var += -elu_mean.pow(2)
        elu_var.clamp_min_(1e-8)  # Avoid negative variance due to numerics
        return elu_mean, elu_var

    def forward(self, mu_inp, var_inp=None):
        s2_w = self.logsig2_w.exp()

        mu_out = F.linear(mu_inp, self.mu_w, self.bias)
        if var_inp is None:
            var_out = F.linear(mu_inp.pow(2), s2_w) + 1e-8
        else:
            var_out = F.linear(var_inp + mu_inp.pow(2), s2_w) + F.linear(
                var_inp, self.mu_w.pow(2)) + 1e-8

        if self.relu_act:
            mu_out, var_out = self.relu_moments(mu_out, var_out.sqrt())

        if self.elu_act:
            mu_out, var_out = self.elu_moments(mu_out, var_out.sqrt())

        return mu_out, var_out  # + 1e-8 Already provided in the moment computation

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.n_in) + ' -> ' \
               + str(self.n_out) \
               + f", activation={self.relu_act or self.elu_act}" \
               + f" ({'relu' if self.relu_act else ('elu' if self.elu_act else '')}))"
Code Example #16
File: linear.py  Project: Bertie97/pyctlib
class Linear(nn.Module):

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, activation="ReLU", hidden_dim=None, hidden_activation="ReLU") -> None:
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_dim = hidden_dim
        self.hidden_activation = hidden_activation
        if hidden_dim is None:
            self.dims = vector(in_features, out_features)
            self.weight = Parameter(torch.zeros(out_features, in_features))
            if bias:
                self.bias = Parameter(torch.zeros(out_features))
            else:
                self.register_parameter('bias', None)
            self.activation = get_activation_layer(activation)
        else:
            self.dims = vector(in_features, *vector(hidden_dim), out_features)
            self.weight = nn.ParameterList(self.dims.map_k(lambda in_dim, out_dim: Parameter(torch.zeros(out_dim, in_dim)), 2))
            if bias:
                self.bias = nn.ParameterList(self.dims.map_k(lambda in_dim, out_dim: Parameter(torch.zeros(out_dim)), 2))
            else:
                self.register_parameter('bias', None)
            self.activation = vector(get_activation_layer(hidden_activation) for _ in range(len(hidden_dim)))
            self.activation.append(get_activation_layer(activation))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.hidden_dim is None:
            if isinstance(self.activation, torch.nn.ReLU) or self.activation == torch.relu:
                init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='relu')
            else:
                init.xavier_normal_(self.weight)
        else:
            for a, w in zip(self.activation, self.weight):
                if isinstance(a, torch.nn.ReLU) or a == torch.relu:
                    init.kaiming_normal_(w, a=0, mode='fan_in', nonlinearity='relu')
                else:
                    init.xavier_normal_(w)

    def forward(self, input: Tensor) -> Tensor:
        if self.hidden_dim is None:
            if self.activation is None:
                return F.linear(input, self.weight, self.bias)
            else:
                return self.activation(F.linear(input, self.weight, self.bias))
        else:
            h = input
            if self.bias is None:
                for w, a in zip(self.weight, self.activation):
                    h = a(F.linear(h, w, None))
            else:
                for w, b, a in zip(self.weight, self.bias, self.activation):
                    h = a(F.linear(h, w, b))
            return h

    def extra_repr(self) -> str:
        if self.activation is None:
            return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features, self.bias is not None)
        elif isinstance(self.activation, vector):
            ret = 'in_features={}, out_features={}, bias={}, activation={}\n'.format(self.in_features, self.out_features, self.bias is not None, self.activation.map(lambda x: touch(lambda: x.__name__, str(x))))
            ret += "{}".format(self.in_features)
            for d, a in zip(self.dims[1:], self.activation):
                ret += '->{}->{}'.format(d, touch(lambda: a.__name__, str(a)))
            return ret
        else:
            ret = 'in_features={}, out_features={}, bias={}, activation={}'.format(self.in_features, self.out_features, self.bias is not None, touch(lambda: self.activation.__name__, str(self.activation)))
            return ret

    def regulization_loss(self, p=2):
        if self.hidden_dim is None:
            if p == 2:
                return self.weight.square().sum()
            if p == 1:
                return self.weight.abs().sum()
            return (self.weight.abs() ** p).sum()
        else:
            reg = []
            for w in self.weight:
                reg.append((w.abs() ** p).sum())
            return sum(reg)
Code Example #17
class group_relaxed_TF1Conv2d(Module):
	"""Implementation of TF1 regularization for the feature maps of a convolutional layer"""
	def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
		lamba=1., alpha=1., beta=4., weight_decay = 1., **kwargs):
		"""
		:param in_channels: Number of input channels
		:param out_channels: Number of output channels
		:param kernel_size: size of the kernel
		:param stride: stride for the convolution
		:param padding: padding for the convolution
		:param dilation: dilation factor for the convolution
		:param groups: how many groups we will assume in the convolution
		:param bias: whether we will use a bias
		:param lamba: strength of the TFL regularization
		"""
		super(group_relaxed_TF1Conv2d, self).__init__()
		self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
		self.in_channels = in_channels
		self.out_channels = out_channels
		self.kernel_size = pair(kernel_size)
		self.stride = pair(stride)
		self.padding = pair(padding)
		self.dilation = pair(dilation)
		self.output_padding = pair(0)
		self.groups = groups
		self.lamba = lamba
		self.alpha = alpha
		self.beta = beta
		self.lamba1 = self.lamba/self.beta
		self.weight_decay = weight_decay
		self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
		self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size)
		self.u = self.u.to('cuda')
		if bias:
			self.bias = Parameter(torch.Tensor(out_channels))
		else:
			self.register_parameter('bias', None)
		self.reset_parameters()
		self.input_shape = None
		print(self)

	def reset_parameters(self):
		init.kaiming_normal_(self.weight, mode='fan_in')

		if self.bias is not None:
			self.bias.data.normal_(0, 1e-2)

	def phi(self,x):
		phi_x = torch.acos(1-27*(self.lamba1*self.alpha*(self.alpha+1))/(2*(self.alpha+x.abs())**3))
		return phi_x

	def g(self,x):
		g_x = x.sign()*(2/3*(self.alpha + x.abs())*torch.cos(self.phi(x)/3)-2*self.alpha/3+x.abs()/3)
		return g_x


	def constrain_parameters(self, thres_std=1.):
		#self.weight.data = F.normalize(self.weight.data, p=2, dim=1)
		#print(torch.sum(self.weight.pow(2)))
		if self.lamba1 <= (self.alpha**2)/(2*(self.alpha+1)):
			t = self.lamba1*(self.alpha+1)/(self.alpha)
		else:
			t = np.sqrt(2*self.lamba1*(self.alpha+1))-self.alpha/2
		self.u.data = self.weight.data.clone()
		self.u.data[self.u.data.abs() <=t] = 0
		g_result = self.g(self.u)
		self.u.data[self.u.data.abs() > t] = g_result[self.u.data.abs() > t]

	def grow_beta(self, growth_factor):
		self.beta = self.beta*growth_factor
		self.lamba1 = self.lamba/self.beta

	def _reg_w(self, **kwargs):
		logpw = -self.beta*torch.sum(0.5*self.weight.add(-self.u).pow(2))-self.lamba*np.sqrt(self.in_channels*self.kernel_size[0]*self.kernel_size[1])*torch.sum(torch.pow(torch.sum(self.weight.pow(2),3).sum(2).sum(1),0.5))
		logpb = 0
		if self.bias is not None:
			logpb = - torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
		return logpw+logpb

	def regularization(self):
		return self._reg_w()

	def count_zero_u(self):
		total = np.prod(self.u.size())
		zero = total - self.u.nonzero().size(0)
		return zero

	def count_zero_w(self):
		return torch.sum((self.weight.abs()<1e-5).int()).item()

	def count_active_neuron(self):
		return torch.sum((torch.sum(self.weight.abs(),3).sum(2).sum(1)/(self.in_channels*self.kernel_size[0]*self.kernel_size[1]))>1e-5).item()

	def count_total_neuron(self):
		return self.out_channels


	def count_weight(self):
		return np.prod(self.u.size())

	def count_expected_flops_and_l0(self):
		#ppos = self.out_channels
		ppos = torch.sum(torch.sum(self.weight.abs(),3).sum(2).sum(1)>0.001).item()
		n = self.kernel_size[0]*self.kernel_size[1]*self.in_channels
		flops_per_instance = n+(n-1)

		num_instances_per_filter = ((self.input_shape[1] -self.kernel_size[0]+2*self.padding[0])/self.stride[0]) + 1
		num_instances_per_filter *=((self.input_shape[2] - self.kernel_size[1]+2*self.padding[1])/self.stride[1]) + 1

		flops_per_filter = num_instances_per_filter * flops_per_instance
		expected_flops = flops_per_filter*ppos
		expected_l0 = n*ppos

		if self.bias is not None:
			expected_flops += num_instances_per_filter*ppos
			expected_l0 += ppos
		return expected_flops, expected_l0

	def forward(self, input_):
		if self.input_shape is None:
			self.input_shape = input_.size()
		output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
		return output

	def __repr__(self):
		s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
			', stride={stride}')
		if self.padding != (0,) * len(self.padding):
			s += ', padding={padding}'
		if self.dilation != (1,) * len(self.dilation):
			s += ', dilation={dilation}'
		if self.output_padding != (0,) * len(self.output_padding):
			s += ', output_padding={output_padding}'
		if self.groups != 1:
			s += ', groups={groups}'
		if self.bias is None:
			s += ', bias=False'
		s += ')'
		return s.format(name=self.__class__.__name__, **self.__dict__)
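A minimal training sketch for the relaxed layer above (not from the original project; head, loader, criterion and num_epochs are placeholders). regularization() returns a log-prior, so it is subtracted from the task loss; constrain_parameters() applies the closed-form TL1 thresholding to the auxiliary variable u, and grow_beta() tightens the coupling between w and u over time:

conv = group_relaxed_TF1Conv2d(3, 16, kernel_size=3, padding=1, lamba=1e-3)
optimizer = torch.optim.SGD(conv.parameters(), lr=0.1)
for epoch in range(num_epochs):
    for x, y in loader:
        loss = criterion(head(conv(x)), y) - conv.regularization()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    conv.constrain_parameters()  # threshold and update u
    conv.grow_beta(1.1)          # illustrative growth factor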
コード例 #18
class FilterStripe(nn.Conv2d):  # convolution layer with a Filter Skeleton (FS) mask
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels, out_channels, kernel_size, stride, kernel_size // 2, groups=1, bias=False)
        self.BrokenTarget = None
        self.FilterSkeleton = Parameter(torch.ones(self.out_channels, self.kernel_size[0], self.kernel_size[1]), requires_grad=True)  # initialize the Filter Skeleton to all ones

    def forward(self, x):  # x: [N, channels, width, height]
        if self.BrokenTarget is not None:
            # out: [N, out_channels, width, height]; np.ceil rounds the strided size up
            out = torch.zeros(x.shape[0], self.FilterSkeleton.shape[0], int(np.ceil(x.shape[2] / self.stride[0])), int(np.ceil(x.shape[3] / self.stride[1])))
            if x.is_cuda:
                out = out.cuda()
            x = F.conv2d(x, self.weight)  # 1x1 convolution over all surviving stripes
            l, h = 0, 0
            for i in range(self.BrokenTarget.shape[0]):
                for j in range(self.BrokenTarget.shape[1]):
                    h += self.FilterSkeleton[:, i, j].sum().item()  # number of filters whose (i, j) stripe survives
                    out[:, self.FilterSkeleton[:, i, j]] += self.shift(x[:, l:h], i, j)[:, :, ::self.stride[0], ::self.stride[1]]  # route each stripe's output back to its filters
                    l += self.FilterSkeleton[:, i, j].sum().item()
            return out
        else:
            # unsqueeze(1) adds an input-channel axis so the skeleton broadcasts over the weights
            return F.conv2d(x, self.weight * self.FilterSkeleton.unsqueeze(1), stride=self.stride, padding=self.padding, groups=self.groups)

    def prune_in(self, in_mask=None):  # in_mask: boolean mask over input channels
        # self.weight shape: [out_channels, in_channels, k, k]
        self.weight = Parameter(self.weight[:, in_mask])  # keep only the selected input channels
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0  # filters with at least one surviving stripe
        if out_mask.sum() == 0:
            out_mask[0] = True  # keep at least one filter
        self.weight = Parameter(self.weight[out_mask])  # prune the convolution kernels
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask], requires_grad=True)  # prune the skeleton
        self.out_channels = out_mask.sum().item()
        return out_mask

    def _break(self, threshold):
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))  # fold the skeleton into the weights
        self.FilterSkeleton = Parameter((self.FilterSkeleton.abs() > threshold), requires_grad=False)  # binarize: True where a stripe survives
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True  # keep at least one stripe
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)  # surviving stripes per kernel position
        self.kernel_size = (1, 1)
        # permute + reshape turns each surviving stripe into an independent 1x1 filter
        self.weight = Parameter(self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1)[self.FilterSkeleton.permute(1, 2, 0).reshape(-1)])

    def update_skeleton(self, sr, threshold):
        self.FilterSkeleton.grad.data.add_(sr * torch.sign(self.FilterSkeleton.data))  # add the subgradient of the L1 penalty on the skeleton
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)  # zero out weak stripes
        self.FilterSkeleton.grad.data.mul_(mask)  # and freeze their gradients
        out_mask = mask.sum(dim=(1, 2)) != 0  # filters that still have at least one active stripe
        return out_mask

    def shift(self, x, i, j):
        # pad so the (i, j) stripe's output lands at its offset from the kernel
        # center (assumes a square BrokenTarget, i.e. a square kernel)
        return F.pad(x, (self.BrokenTarget.shape[0] // 2 - j, j - self.BrokenTarget.shape[0] // 2, self.BrokenTarget.shape[0] // 2 - i, i - self.BrokenTarget.shape[1] // 2), 'constant', 0)

    def extra_repr(self):
        s = ('{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)
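A minimal usage sketch of the stripe-pruning flow above (the sr and threshold values are illustrative, not taken from the project):

layer = FilterStripe(16, 32, kernel_size=3)
x = torch.randn(8, 16, 28, 28)

# During training: add the L1 subgradient to the skeleton and zero weak
# stripes after every backward pass.
layer(x).sum().backward()
layer.update_skeleton(sr=1e-4, threshold=1e-2)

# After training: binarize the skeleton and switch to the stripe-wise
# 1x1-convolution execution path.
layer._break(threshold=1e-2)
out = layer(x)  # now runs the BrokenTarget branch of forward()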
コード例 #19
class group_relaxed_L1L2Dense(Module):
    """Implementation of TFL regularization for the input units of a fully connected layer"""
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 lamba=1.,
                 alpha=1.,
                 beta=4.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the L1-L2 regularization
		"""
        super(group_relaxed_L1L2Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.u = torch.rand(in_features, out_features)
        # keep the auxiliary variable u on the GPU only when one is available
        if torch.cuda.is_available():
            self.u = self.u.cuda()
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) +
                               self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0)
            # sign-matched entries scaled so that ||u||_2 = alpha * lamba1
            self.u[self.u != 0] = self.weight.sign()[
                self.u != 0] * self.alpha * self.lamba1 / (n**(1 / 2))

        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            max_idx = np.unravel_index(torch.argmax(self.u.cpu(), None),
                                       self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w +
                               (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(
            0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(
                self.out_features) * torch.sum(
                    torch.pow(torch.sum(self.weight.pow(2), 1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(
            torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__+' (' \
         + str(self.in_features) + ' -> ' \
         + str(self.out_features) + ', lambda: ' \
         + str(self.lamba) + ')'
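Reading constrain_parameters back into math: the four branches appear to implement the closed-form proximal operator of the L1-minus-L2 penalty lamba1 * (||w||_1 - alpha * ||w||_2); a sketch reconstructed from the code, with soft-thresholding operator soft and lambda_1 = lambda / beta:

\[
u =
\begin{cases}
z \, \dfrac{\lVert z \rVert_2 + \alpha\lambda_1}{\lVert z \rVert_2}, \quad z = \operatorname{soft}_{\lambda_1}(w), & \lVert w \rVert_\infty > \lambda_1, \\
\text{sign-matched support of } w \text{ rescaled so } \lVert u \rVert_2 = \alpha\lambda_1, & \lVert w \rVert_\infty = \lambda_1, \\
\bigl(\lVert w \rVert_\infty + (\alpha - 1)\lambda_1\bigr)\operatorname{sign}(w_{i^*})\, e_{i^*}, & (1-\alpha)\lambda_1 < \lVert w \rVert_\infty < \lambda_1, \\
0, & \text{otherwise},
\end{cases}
\]

where e_{i^*} is the standard basis vector at the entry the code selects via argmax.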
コード例 #20
File: layers.py  Project: timko98/hydra
class SubnetLinear(nn.Linear):
    # self.k is the % of weights remaining, a real number in [0,1]
    # self.popup_scores is a Parameter which has the same shape as self.weight
    # Gradients to self.weight, self.bias have been turned off.

    def __init__(self, in_features, out_features, bias=True):
        super(SubnetLinear, self).__init__(in_features,
                                           out_features,
                                           bias=True)
        # Weight pruning
        # self.popup_scores = Parameter(torch.Tensor(self.weight.shape))
        # Channel Finetuning or Resume Pruning
        # self.popup_scores = Parameter(torch.Tensor(torch.Size([1,self.weight.shape[1]])))
        # Channel Pruning
        self.popup_scores = Parameter(
            torch.Tensor(torch.Size([self.weight.shape[0], 1])))

        nn.init.kaiming_uniform_(self.popup_scores, a=math.sqrt(5))
        self.weight.requires_grad = False
        self.bias.requires_grad = False
        self.w = 0
        # self.register_buffer('w', None)

    def set_prune_rate(self, k):
        self.k = k

    def forward(self, x):
        """ Unstructured Comparison
        remaining_weights = int(self.k * len(self.weight.flatten()))
        idx_same_top_weights_scores = list(
            set(torch.topk(self.weight.abs().flatten(), remaining_weights).indices.tolist()).intersection(
                set(torch.topk(self.popup_scores.abs().flatten(), remaining_weights).indices.tolist())))
        num_remaining_weights = len(idx_same_top_weights_scores)
        print(
            f"SubnetLinear: Number of same indices for scores and weights that are left after pruning: "
            f"{num_remaining_weights}. These are {float(num_remaining_weights / remaining_weights)} percent of the "
            f"weights kept.")
        """
        """ Structured Comparison
        remaining_filters = int(self.k * self.weight.shape[0])
        idx_same_top_weights_scores = list(set(
            torch.topk(torch.linalg.norm(self.weight.abs().reshape(self.weight.shape[0], -1), 1, dim=1),
                       remaining_filters).indices.tolist()).intersection(
            torch.topk(torch.linalg.norm(self.popup_scores.abs().reshape(self.popup_scores.shape[0], -1), 1, dim=1),
                       remaining_filters).indices.tolist()))
        num_remaining_filters = len(idx_same_top_weights_scores)
        print(
            f"SubnetLinear: Number of same indices for filters that are left after pruning using scores or weights : "
            f"{num_remaining_filters}. These are {float(num_remaining_filters / remaining_filters)} percent of the "
            f"filters kept.")
        """
        """ Channel Prune VGG16
        global linear_nr
        if linear_nr == 3:
            linear_nr = 1
        else:
            linear_nr += 1
        # Get the subnetwork by sorting the scores.
        mask_linear_50 = [0.10107422, 0.1015625, 0.1015625]
        mask_linear_10 = [0.016601562, 0.015625, 0.015625]
        k = mask_linear_10[linear_nr-1]
        adj = GetSubnet.apply(self.popup_scores.abs(), self.k)
        """
        # Fixed mask WRN Channel Prune
        # adj = GetSubnet.apply(self.popup_scores.abs(), 0.44140625)
        adj = GetSubnet.apply(self.popup_scores.abs(), self.k)

        # Use only the subnetwork in the forward pass.
        self.w = self.weight * adj
        x = F.linear(x, self.w, self.bias)

        return x
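GetSubnet itself is defined elsewhere in the hydra project. It is typically an edge-popup-style straight-through top-k mask along these lines (a sketch, not necessarily identical to the project's version):

class GetSubnet(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scores, k):
        # Binarize: keep the top k fraction of scores as 1, zero the rest.
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1 - k) * scores.numel())
        flat_out = out.flatten()  # flat_out and out share memory
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1
        return out

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: pass the gradient to the scores unchanged.
        return g, None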
コード例 #21
class group_relaxed_L1L2Conv2d(Module):
    """Implementation of TF1 regularization for the feature maps of a convolutional layer"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 lamba=1.,
                 alpha=1.,
                 beta=4.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_channels: Number of input channels
		:param out_channels: Number of output channels
		:param kernel_size: size of the kernel
		:param stride: stride for the convolution
		:param padding: padding for the convolution
		:param dilation: dilation factor for the convolution
		:param groups: how many groups we will assume in the convolution
		:param bias: whether we will use a bias
		:param lamba: strength of the L1-L2 regularization
		"""
        super(group_relaxed_L1L2Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.u = torch.rand(out_channels, in_channels // groups,
                            *self.kernel_size)
        # keep the auxiliary variable u on the GPU only when one is available
        if torch.cuda.is_available():
            self.u = self.u.cuda()
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) +
                               self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0)
            # sign-matched entries scaled so that ||u||_2 = alpha * lamba1
            self.u[self.u != 0] = self.weight.sign()[
                self.u != 0] * self.alpha * self.lamba1 / (n**(1 / 2))

        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            max_idx = np.unravel_index(torch.argmax(self.u.cpu(), None),
                                       self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w +
                               (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(
            0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(
                self.in_channels * self.kernel_size[0] * self.kernel_size[1]
            ) * torch.sum(
                torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) /
                          (self.in_channels * self.kernel_size[0] *
                           self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_weight(self):
        return np.prod(self.u.size())

    def count_expected_flops_and_l0(self):
        #ppos = self.out_channels
        ppos = torch.sum(
            torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)

        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) /
            self.stride[0]) + 1
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) /
            self.stride[1]) + 1

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos

        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
コード例 #22
class lq_conv2d_orig(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=False,
                 padding_mode='zeros',
                 is_qt=False,
                 tr_gamma=True,
                 lq=False,
                 block_num=-1,
                 layer_num=-1,
                 index=[],
                 fwlq=False):
        super(lq_conv2d_orig,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias, padding_mode)
        self.block_num = block_num
        self.layer_num = layer_num
        self.index = index
        self.w_shape = self.weight.shape

        self.is_qt = is_qt
        self.lq = lq
        self.fwlq = fwlq

        #if groups != 1:
        #    self.lq = False

        if lq:
            if fwlq:
                print("filter-wise learning to quantize")
                self.cw = Parameter(torch.ones([out_channels, 1]))
                self.dw = Parameter(torch.ones([out_channels, 1]))
                self.gamma = Parameter(torch.ones([out_channels, 1
                                                   ])) if tr_gamma else 1
            else:
                self.cw = Parameter(torch.Tensor([1]))
                self.dw = Parameter(torch.Tensor([1]))
                self.gamma = Parameter(torch.Tensor([1])) if tr_gamma else 1
            self.cx = Parameter(torch.Tensor([2]))
            self.dx = Parameter(torch.Tensor([2]))
            self.tr_gamma = tr_gamma

    def set_bit_width(self, w_bit, x_bit, initskip):
        self.w_bit = w_bit
        self.x_bit = x_bit
        if isinstance(x_bit, list):
            self.qx = [2**(bit) - 1 for bit in x_bit]
            # uniform initial mixture over the candidate bit widths
            self.theta_x = Parameter(torch.ones([len(x_bit)]) / len(x_bit))
        else:
            self.qx = 2**(x_bit) - 1
        if isinstance(w_bit, list):
            self.qw = [2**(bit - 1) - 1 for bit in w_bit]
            self.theta_w = Parameter(torch.ones([len(w_bit)]) / len(w_bit))
        else:
            self.qw = 2**(w_bit - 1) - 1

        # Read filterwise bitwidth index
        if self.index != []:
            self.qw = torch.ones((self.w_shape[0], 1))
            bit_max = 9
            for i in range(bit_max):
                if len(self.index[self.block_num][self.layer_num][i]) == 0:
                    continue
                else:
                    idx = self.index[self.block_num][self.layer_num][i]
                    self.qw[idx] = 2**(i + 1) - 1
            self.qw = self.qw.cuda()

        # Initialize c, d
        if self.lq and not initskip:
            with torch.no_grad():
                if self.fwlq:
                    self.cw *= self.weight.abs().mean(
                    )  #(dim=[1,2,3]).view((-1,1))
                    self.dw *= self.weight.std()  #(dim=[1,2,3]).view((-1,1))
                else:
                    self.cw *= self.weight.abs().mean()
                    self.dw *= self.weight.std()

    def bitops_count(self, soft_mask_w=None, soft_mask_x=None):
        x_shape = torch.Tensor([self.x_shape])
        w_shape = torch.Tensor([self.weight.shape])
        flops = x_shape.prod()
        flops *= w_shape.prod()
        # case 1: soft mask over w bit widths, single x bit width
        if soft_mask_x is None and soft_mask_w is not None:
            bitops = torch.Tensor(self.w_bit).cuda() * soft_mask_w * self.x_bit
            bitops *= flops
            bitops = bitops.sum()
        # case 0: treat both w and x as 32-bit
        elif soft_mask_x is None and soft_mask_w is None:
            bitops = torch.Tensor([32 * 32]).cuda()
            bitops *= flops
        return bitops

    def forward(self, input):
        self.x_shape = input.shape[2:]
        soft_mask_w = None
        if self.lq:
            w_abs = self.weight.abs()
            w_sign = self.weight.sign()

            w_abs = w_abs.view(self.w_shape[0], -1)
            w_sign = w_sign.view(self.w_shape[0], -1)

            eps = 1e-7
            _dw = self.dw.abs() + eps
            _dx = self.dx

            # yejun: d, gamma (no c)
            # Transformer_W
            w_mask1 = (w_abs <= _dw).type(torch.float).detach()
            w_mask2 = (w_abs > _dw).type(torch.float).detach()

            w_cal = w_abs / _dw
            nan_detect(w_cal)
            nan_detect(w_cal.pow(self.gamma))
            w_hat = (w_mask2 * w_sign) + (w_mask1 *
                                          (w_cal).pow(self.gamma) * w_sign)
            nan_detect(w_hat)

            # Discretizer_W
            if isinstance(self.qw, list):
                # 1. learning bitwidth
                w_bar_list = []
                for qw in self.qw:
                    w_bar = Round.apply(w_hat * qw) / qw
                    nan_detect(w_bar)
                    w_bar_list.append(w_bar)
                soft_mask_w = nn.functional.gumbel_softmax(self.theta_w,
                                                           tau=1,
                                                           hard=False)
                w_bar = sum(w * theta
                            for w, theta in zip(w_bar_list, soft_mask_w))
                w_bar = w_bar.view(self.w_shape)
            else:
                # 2. fixed bitwidth
                w_bar = Round.apply(w_hat * self.qw) / self.qw
                nan_detect(w_bar)
                w_bar = w_bar.view(self.w_shape)
                nan_detect(w_bar)

            # Transformer X
            x_mask1 = (input <= self.dx).type(torch.float).detach()
            x_mask2 = (input > self.dx).type(torch.float).detach()
            x_cal = input / self.dx
            nan_detect(x_cal)
            x_hat = x_mask1 * x_cal + x_mask2
            nan_detect(x_hat)

            # Discretizer X
            if isinstance(self.qx, list):
                # 1. learning bitwidth
                x_bar_list = []
                for qx in self.qx:
                    x_bar = Round.apply(x_hat * qx) / qx
                    nan_detect(x_bar)
                    x_bar_list.append(x_bar)
                soft_mask_x = nn.functional.gumbel_softmax(self.theta_x,
                                                           tau=1,
                                                           hard=False)
                x_bar = sum(x * theta
                            for x, theta in zip(x_bar_list, soft_mask_x))
            else:
                # 2. fixed bitwidth
                x_bar = Round.apply(x_hat * self.qx) / self.qx
                nan_detect(x_bar)
            y = F.conv2d(x_bar, w_bar, self.bias, self.stride, self.padding,
                         self.dilation, self.groups)

        elif self.is_qt:
            if isinstance(self.qw, list):
                w_list = []
                for qw in self.qw:
                    w_list.append(quantize(self.weight, num_bits=qw))
                soft_mask_w = nn.functional.gumbel_softmax(self.theta_w,
                                                           tau=1,
                                                           hard=False)
                w = sum(w_ * theta for w_, theta in zip(w_list, soft_mask_w))
                #w= w_bar.view(self.w_shape)

            else:
                w = quantize(self.weight,
                             num_bits=self.w_bit,
                             block_num=self.block_num,
                             layer_num=self.layer_num,
                             multi=True,
                             index=self.index)
            x = quantize(input, num_bits=self.x_bit, is_act=True)
            y = F.conv2d(x, w, self.bias, self.stride, self.padding,
                         self.dilation, self.groups)
        else:
            y = F.conv2d(input, self.weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
        flops = self.bitops_count(soft_mask_w=soft_mask_w, soft_mask_x=None)
        return y, flops
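Round, quantize and nan_detect come from elsewhere in the project. Assuming those helpers are importable and a GPU is present (bitops_count calls .cuda()), a minimal call sketch with illustrative bit widths and shapes:

conv = lq_conv2d_orig(16, 32, 3, padding=1, lq=True).cuda()
conv.set_bit_width(w_bit=4, x_bit=4, initskip=False)  # fixed 4-bit weights and activations
y, bitops = conv(torch.randn(8, 16, 32, 32).cuda())   # forward also returns a BitOPs estimate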
コード例 #23
class AttentionReadout(Readout):
    def __init__(
        self,
        in_shape: Tuple[int, int, int],
        outdims: int,
        bias: bool,
        init_noise: float = 1e-3,
        attention_kernel: int = 1,
        attention_layers: int = 1,
        mean_activity: Optional[Mapping[str, float]] = None,
        feature_reg_weight: float = 1.0,
        gamma_readout: Optional[
            float] = None,  # deprecated, use feature_reg_weight instead
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.in_shape = in_shape
        self.outdims = outdims
        self.feature_reg_weight = self.resolve_deprecated_gamma_readout(
            feature_reg_weight, gamma_readout)  # type: ignore[no-untyped-call]
        self.mean_activity = mean_activity
        c, w, h = in_shape
        self.features = Parameter(torch.Tensor(self.outdims, c))

        attention = Sequential()
        for i in range(attention_layers - 1):
            attention.add_module(
                f"conv{i}",
                Conv2d(c, c, attention_kernel, padding=attention_kernel > 1),
            )
            attention.add_module(
                f"norm{i}", BatchNorm2d(c))  # type: ignore[no-untyped-call]
            attention.add_module(f"nonlin{i}", ELU())
        # final layer: maps the features to one attention map per output neuron
        attention.add_module(
            f"conv{attention_layers}",
            Conv2d(c,
                   outdims,
                   attention_kernel,
                   padding=attention_kernel > 1),
        )
        self.attention = attention

        self.init_noise = init_noise
        if bias:
            bias_param = Parameter(torch.Tensor(self.outdims))
            self.register_parameter("bias", bias_param)
        else:
            self.register_parameter("bias", None)
        self.initialize(mean_activity)

    @staticmethod
    def init_conv(m: Module) -> None:
        if isinstance(m, Conv2d):
            init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0)

    def initialize_attention(self) -> None:
        self.apply(self.init_conv)

    def initialize(
        self,
        mean_activity: Optional[Mapping[str, float]] = None
    ) -> None:  # type: ignore[override]
        if mean_activity is None:
            mean_activity = self.mean_activity
        self.features.data.normal_(0, self.init_noise)
        if self.bias is not None:
            self.initialize_bias(
                mean_activity=mean_activity)  # type: ignore[no-untyped-call]
        self.initialize_attention()

    def feature_l1(self,
                   reduction: Literal["sum", "mean", None] = "sum",
                   average: Optional[bool] = None) -> torch.Tensor:
        return self.apply_reduction(
            self.features.abs(), reduction=reduction,
            average=average)  # type: ignore[no-untyped-call,no-any-return]

    def regularizer(self,
                    reduction: Literal["sum", "mean", None] = "sum",
                    average: Optional[bool] = None) -> torch.Tensor:
        return self.feature_l1(
            reduction=reduction, average=average
        ) * self.feature_reg_weight  # type: ignore[no-any-return]

    def forward(self,
                x: torch.Tensor,
                shift: Optional[Any] = None) -> torch.Tensor:
        attention = self.attention(x)
        b, c, w, h = attention.shape
        attention = F.softmax(attention.view(b, c, -1),
                              dim=-1).view(b, c, w, h)
        y: torch.Tensor = torch.einsum("bnwh,bcwh->bcn", attention,
                                       x)  # type: ignore[attr-defined]
        y = torch.einsum("bcn,nc->bn", y,
                         self.features)  # type: ignore[attr-defined]
        if self.bias is not None:
            y = y + self.bias
        return y

    def __repr__(self) -> str:
        return self.__class__.__name__ + " (" + "{} x {} x {}".format(
            *self.in_shape) + " -> " + str(self.outdims) + ")"
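A minimal shape walk-through (the Readout base class and its helpers come from neuralpredictors and are assumed available; bias=False sidesteps the initialize_bias dependency):

readout = AttentionReadout(in_shape=(64, 14, 14), outdims=100, bias=False,
                           attention_kernel=3, attention_layers=2)
x = torch.randn(8, 64, 14, 14)
# attention(x): (8, 100, 14, 14), softmax over the 14*14 positions;
# "bnwh,bcwh->bcn" gives (8, 64, 100), then "bcn,nc->bn" gives (8, 100)
y = readout(x)
assert y.shape == (8, 100)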
コード例 #24
class group_relaxed_SCAD_Dense(Module):
	"""Implementation of TFL regularization for the input units of a fully connected layer"""
	def __init__(self, in_features, out_features, bias=True, lamba=1., alpha = 3.7, beta = 4.0, weight_decay=1., **kwargs):
		"""
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the SCAD regularization
		"""
		super(group_relaxed_SCAD_Dense,self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.weight = Parameter(torch.Tensor(in_features, out_features))
		self.u = torch.rand(in_features, out_features)
		# keep the auxiliary variable u on the GPU only when one is available
		if torch.cuda.is_available():
			self.u = self.u.cuda()
		if bias:
			self.bias = Parameter(torch.Tensor(out_features))
		else:
			self.register_parameter('bias', None)
		self.lamba = lamba
		self.alpha = alpha
		self.beta = beta
		self.lamba1 = self.lamba/self.beta
		self.weight_decay = weight_decay
		self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
		self.reset_parameters()
		print(self)

	def reset_parameters(self):
		init.kaiming_normal_(self.weight, mode='fan_out')

		if self.bias is not None:
			self.bias.data.normal_(0,1e-2)


	def constrain_parameters(self, **kwargs):
		self.u = self.weight.clone()
		s = Softshrink(self.lamba1)
		# soft-thresholding on values with absolute value at most 2*lamba1
		shrink_value = s(self.weight.data)
		self.u[self.weight.abs() <= 2*self.lamba1] = shrink_value[self.weight.abs() <= 2*self.lamba1]

		# modify values whose absolute value lies between 2*lamba1 and alpha*lamba1
		modify_weight = self.weight.data
		modify_weight = ((self.alpha - 1)*modify_weight - modify_weight.sign()*(self.alpha*self.lamba1))/(self.alpha - 2)
		self.u[(self.weight.abs() > 2*self.lamba1) & (self.weight.abs() <= self.alpha*self.lamba1)] = modify_weight[(self.weight.abs() > 2*self.lamba1) & (self.weight.abs() <= self.alpha*self.lamba1)]


	def grow_beta(self, growth_factor):
		self.beta = self.beta*growth_factor
		self.lamba1 = self.lamba/self.beta

	def _reg_w(self, **kwargs):
		logpw = -self.beta*torch.sum(0.5*self.weight.add(-self.u).pow(2))-self.lamba*np.sqrt(self.out_features)*torch.sum(torch.pow(torch.sum(self.weight.pow(2),1),0.5))
		logpb = 0
		if self.bias is not None:
			logpb = - torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
		return logpw + logpb

	def regularization(self):
		return self._reg_w()

	def count_zero_u(self):
		total = np.prod(self.u.size())
		zero = total - self.u.nonzero().size(0)
		return zero

	def count_zero_w(self):
		return torch.sum((self.weight.abs()<1e-5).int()).item()

	def count_weight(self):
		return np.prod(self.u.size())

	def count_active_neuron(self):
		return torch.sum(torch.sum(self.weight.abs()/self.out_features,1)>1e-5).item()

	def count_total_neuron(self):
		return self.in_features

	def count_expected_flops_and_l0(self):
		ppos = torch.sum(self.weight.abs()>0.000001).item()
		expected_flops = (2*ppos-1)*self.out_features
		expected_l0 = ppos*self.out_features
		if self.bias is not None:
			expected_flops += self.out_features
			expected_l0 += self.out_features
		return expected_flops, expected_l0

	def forward(self, input):
		output = input.mm(self.weight)
		if self.bias is not None:
			output.add_(self.bias.view(1, self.out_features).expand_as(output))
		return output

	def __repr__(self):
		return self.__class__.__name__+' (' \
			+ str(self.in_features) + ' -> ' \
			+ str(self.out_features) + ', lambda: ' \
			+ str(self.lamba) + ')'
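Written out, constrain_parameters above is the standard SCAD thresholding rule with lambda_1 = lambda / beta (alpha defaults to the conventional 3.7):

\[
u_i =
\begin{cases}
\operatorname{sign}(w_i)\,\bigl(\lvert w_i \rvert - \lambda_1\bigr)_+, & \lvert w_i \rvert \le 2\lambda_1, \\
\dfrac{(\alpha - 1)\,w_i - \operatorname{sign}(w_i)\,\alpha\lambda_1}{\alpha - 2}, & 2\lambda_1 < \lvert w_i \rvert \le \alpha\lambda_1, \\
w_i, & \lvert w_i \rvert > \alpha\lambda_1.
\end{cases}
\]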
コード例 #25
class sparse_group_lasso_Conv2d(Module):
    """Implementation of TF1 regularization for the feature maps of a convolutional layer"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 lamba=1.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_channels: Number of input channels
		:param out_channels: Number of output channels
		:param kernel_size: size of the kernel
		:param stride: stride for the convolution
		:param padding: padding for the convolution
		:param dilation: dilation factor for the convolution
		:param groups: how many groups we will assume in the convolution
		:param bias: whether we will use a bias
		:param lamba: strength of the sparse group lasso regularization
		"""
        super(sparse_group_lasso_Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.weight_decay = weight_decay
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, thres_std=1.):
        pass

    def _reg_w(self, **kwargs):
        logpw = -self.lamba * np.sqrt(
            self.in_channels * self.kernel_size[0] *
            self.kernel_size[1]) * torch.sum(
                torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1),
                          0.5)) - torch.sum(self.lamba * self.weight.abs())
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.weight.size())

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) /
                          (self.in_channels * self.kernel_size[0] *
                           self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_expected_flops_and_l0(self):
        #ppos = self.out_channels
        ppos = torch.sum(
            torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)

        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) /
            self.stride[0]) + 1
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) /
            self.stride[1]) + 1

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos

        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
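Reading off _reg_w, regularization() returns the negative of a sparse group lasso penalty (plus weight decay on the bias):

\[
R(W) = \lambda \sqrt{C_{\mathrm{in}}\, k_h k_w} \sum_{g=1}^{C_{\mathrm{out}}} \lVert W_g \rVert_2 + \lambda \lVert W \rVert_1,
\]

where W_g is the kernel of output channel g and the square-root factor is the usual group-size scaling.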
コード例 #26
class CGES_Dense(Module):
    """Implementation of TFL regularization for the input units of a fully connected layer"""
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 lamba=1.,
                 weight_decay=1.,
                 mu=0.5,
                 **kwargs):
        """
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the CGES regularization
		"""
        super(CGES_Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.mu = mu
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        pass

    def _reg_w(self, **kwargs):
        logpw = -(1 - self.mu) * self.lamba * torch.sum(
            torch.pow(torch.sum(self.weight.pow(2), 1),
                      0.5)) - self.mu * self.lamba / 2 * torch.sum(
                          torch.sum(self.weight.abs(), 1)**2)
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def set_mu(self, mu):
        self.mu = mu

    def regularization(self):
        return self._reg_w()

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.weight.size())

    def count_active_neuron(self):
        return torch.sum(
            torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__+' (' \
         + str(self.in_features) + ' -> ' \
         + str(self.out_features) + ', lambda: ' \
         + str(self.lamba) + ')'
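For reference, _reg_w above is the negative of the CGES penalty, with mu (adjustable via set_mu) trading the group lasso term against the exclusive sparsity term:

\[
R(W) = (1 - \mu)\,\lambda \sum_{i=1}^{d_{\mathrm{in}}} \lVert W_{i,:} \rVert_2 + \frac{\mu\lambda}{2} \sum_{i=1}^{d_{\mathrm{in}}} \lVert W_{i,:} \rVert_1^2,
\]

where W_{i,:} is the outgoing weight row of input unit i.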