class MAPDense(Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 weight_decay=1.,
                 **kwargs):
        super(MAPDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.weight_decay = weight_decay
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, thres_std=1.):
        pass

    def eq_logpw(self, **kwargs):
        logpw = -torch.sum(self.weight_decay * .5 * (self.weight.pow(2)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def eq_logqw(self):
        return 0.

    def kldiv_aux(self):
        return 0.

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw() + self.kldiv_aux()

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
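# Usage sketch (not part of the original code): the MAP log-prior returned by
# kldiv()/eq_logpw() is added to the task loss. `model`, `logits`, `targets` and
# `n_train` are illustrative names; a classification loss is assumed.
import torch.nn.functional as F

def map_loss(model, logits, targets, n_train):
    nll = F.cross_entropy(logits, targets)
    # kldiv() is a (negative) log-prior, so subtracting it adds a
    # 0.5 * weight_decay * ||w||^2 penalty, scaled by the dataset size.
    log_prior = sum(m.kldiv() for m in model.modules() if isinstance(m, MAPDense))
    return nll - log_prior / n_train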
class stimGLM(Poisson):
    def __init__(self, input_dim=(12, 1),
        num_directions=12,
        output_dim=128,
        d2t=0.1,
        **kwargs):

        super().__init__()
        self.save_hyperparameters()
        
        self.directionTuning = Parameter(torch.Tensor(size=(self.hparams.num_directions, self.hparams.output_dim)))
        self.directionKernel = Parameter(torch.Tensor(size=(self.hparams.input_dim[0], self.hparams.output_dim)))
        self.bias = Parameter(torch.Tensor(size=(1, self.hparams.output_dim)))
        self.spikeNL = nn.Softplus()
        self.directionTuning.data = torch.randn(self.directionTuning.shape)
        self.directionKernel.data = torch.randn(self.directionKernel.shape)
        self.bias.data = torch.rand(self.bias.shape)
        
    def regularizer(self):
        d2tdir = self.directionKernel.diff(dim=0).pow(2).sum()
        l2dir = self.directionTuning.pow(2).sum()
        return self.hparams.d2t * d2tdir + self.hparams.d2t * l2dir
        # self.contrast.weight.data.diff()

    def forward(self, sample):
        x = torch.einsum('nld,lc->ndc', sample['direction'], self.directionKernel)
        x = torch.einsum('ndc,dc->nc', x, self.directionTuning)
        x = self.spikeNL(x + self.bias)
        return x
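# Shape-only sketch of the two einsum contractions in stimGLM.forward, under assumed
# dimensions (batch N, lags L, directions D, cells C); not part of the original code.
import torch

N, L, D, C = 8, 12, 12, 128
direction = torch.randn(N, L, D)                               # sample['direction']
directionKernel = torch.randn(L, C)
directionTuning = torch.randn(D, C)
x = torch.einsum('nld,lc->ndc', direction, directionKernel)    # (N, D, C): temporal filtering per cell
x = torch.einsum('ndc,dc->nc', x, directionTuning)             # (N, C): weight directions per cell
print(x.shape)                                                 # torch.Size([8, 128])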
class Lap2D(nn.Module):
    """
    Isotropic (Lap2D shape) Lap2D filter 
    """
    def __init__(self, w, n_gaussian,learn_amplitude=False):
        super(Lap2D, self).__init__()
        self.xes = torch.FloatTensor(range(int(-w / 2)+1,  int(w / 2) + 1)).unsqueeze(-1)
        self.xes = self.xes.repeat(self.xes.size(0), 1, n_gaussian)
        self.yes = self.xes.transpose(1, 0)

        self.xypod = Parameter(self.xes * self.yes, requires_grad=False)
        self.xes = Parameter(self.xes.pow(2), requires_grad=False)
        self.yes = Parameter(self.yes.pow(2), requires_grad=False)
        self.padding = int(w / 2)
        self.s = Parameter(torch.randn(n_gaussian).float(), requires_grad=True)
        print("Current Lap2D")
        
    def weights_init(self):
        self.s.data.normal_(1.,0.3)

    def get_gaussian(self,s):
        #return  (- (self.xes + self.yes) / (2 * s.pow(2))).exp()/(2.4569*s)
        return  (- (self.xes + self.yes)*s.pow(2) / 2).exp()/(2.4569)*s
        
    def get_filter(self, s=None):
        """
        :param s: per-filter scale parameters; defaults to the learned self.s
        :return: filter bank with one 2-D kernel per gaussian, laid out for F.conv2d
        """
        if s is None:
            s = self.s

        eps = 1e-3
        k = s.pow(2) / eps
        # finite difference of two Gaussians in scale, rescaled by s^2 / eps
        filters = self.get_gaussian(s) - self.get_gaussian(s + eps)

        return (k * filters).transpose(0, 2).unsqueeze(1).contiguous()

    def forward(self, x):
        filters = self.get_filter(self.s)
        return F.conv2d(x, filters, padding=self.padding, groups=x.size(1))  
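# Usage sketch (assumed): with n_gaussian equal to the number of input channels, the
# filter bank acts as a depthwise difference-of-Gaussians convolution.
import torch

lap = Lap2D(w=9, n_gaussian=3)
lap.weights_init()                  # s ~ N(1, 0.3), one scale per filter
img = torch.randn(2, 3, 32, 32)     # (batch, channels, H, W); channels == n_gaussian
out = lap(img)
print(out.shape)                    # one filtered map per input channel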
class AngleSoftmax(nn.Module):
    def __init__(self,
                 input_size,
                 output_size,
                 normalize=True,
                 m=4,
                 lambda_max=1000.0,
                 lambda_min=5.0,
                 power=1.0,
                 gamma=0.1,
                 loss_weight=1.0):
        """
        :param input_size: Input channel size.
        :param output_size: Number of Class.
        :param normalize: Whether do weight normalization.
        :param m: An integer, specifying the margin type, take value of [0,1,2,3,4,5].
        :param lambda_max: Starting value for lambda.
        :param lambda_min: Minimum value for lambda.
        :param power: Decreasing strategy for lambda.
        :param gamma: Decreasing strategy for lambda.
        :param loss_weight: Loss weight for this loss.
        """
        super(AngleSoftmax, self).__init__()
        self.loss_weight = loss_weight
        self.normalize = normalize
        self.weight = Parameter(torch.Tensor(int(output_size), input_size))
        nn.init.kaiming_uniform_(self.weight, 1.0)
        self.m = m

        self.it = 0
        self.LambdaMin = lambda_min
        self.LambdaMax = lambda_max
        self.gamma = gamma
        self.power = power

    def forward(self, x, y):

        if self.normalize:
            wl = self.weight.pow(2).sum(1).pow(0.5)
            wn = self.weight / wl.view(-1, 1)
            self.weight.data.copy_(wn.data)
        if self.training:
            lamb = max(self.LambdaMin,
                       self.LambdaMax / (1 + self.gamma * self.it)**self.power)
            self.it += 1
            phi_kernel = PhiKernel(self.m, lamb)
            feat = phi_kernel(x, self.weight, y)
            loss = F.nll_loss(F.log_softmax(feat, dim=1), y)
        else:
            feat = x.mm(self.weight.t())
            self.prob = F.log_softmax(feat, dim=1)
            loss = F.nll_loss(self.prob, y)

        return loss.mul_(self.loss_weight)
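# The margin term is blended in through lambda, which decays from lambda_max towards
# lambda_min as training proceeds (see the expression in forward()). A quick look at
# the schedule under the default hyperparameters:
lambda_max, lambda_min, gamma, power = 1000.0, 5.0, 0.1, 1.0
for it in (0, 10, 100, 1000):
    lamb = max(lambda_min, lambda_max / (1 + gamma * it) ** power)
    print(it, round(lamb, 2))       # 1000.0, 500.0, 90.91, 9.9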
class VBLinear(nn.Module):
    def __init__(self, in_features, out_features, prior_prec=10, map=True):
        super(VBLinear, self).__init__()
        self.n_in = in_features
        self.n_out = out_features

        self.prior_prec = prior_prec
        self.map = map

        self.bias = nn.Parameter(th.Tensor(out_features))
        self.mu_w = Parameter(th.Tensor(out_features, in_features))
        self.logsig2_w = nn.Parameter(th.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        # TODO: Adapt to the newest pytorch initializations
        stdv = 1. / math.sqrt(self.mu_w.size(1))
        self.mu_w.data.normal_(0, stdv)
        self.logsig2_w.data.zero_().normal_(-9, 0.001)  # var init via Louizos
        self.bias.data.zero_()

    def KL(self, loguniform=False):
        if loguniform:
            k1 = 0.63576
            k2 = 1.87320
            k3 = 1.48695
            log_alpha = self.logsig2_w - 2 * th.log(self.mu_w.abs() + 1e-8)
            kl = -th.sum(k1 * th.sigmoid(k2 + k3 * log_alpha) -
                         0.5 * F.softplus(-log_alpha) - k1)
        else:
            logsig2_w = self.logsig2_w.clamp(-11, 11)
            kl = 0.5 * (self.prior_prec *
                        (self.mu_w.pow(2) + logsig2_w.exp()) - logsig2_w - 1 -
                        np.log(self.prior_prec)).sum()
        return kl

    def forward(self, input):
        # Sampling-free forward pass when making MAP predictions outside of training
        if self.map and not self.training:
            return F.linear(input, self.mu_w, self.bias)
        else:
            mu_out = F.linear(input, self.mu_w, self.bias)
            logsig2_w = self.logsig2_w.clamp(-11, 11)
            s2_w = logsig2_w.exp()
            var_out = F.linear(input.pow(2), s2_w) + 1e-8
            return mu_out + var_out.sqrt() * th.randn_like(mu_out)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.n_in) + ' -> ' \
               + str(self.n_out) + ')'
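# Small numerical check (not from the original code) that the pre-activation variance
# used in VBLinear.forward, F.linear(input.pow(2), s2_w), matches sampling the weights
# directly from q(W) = N(mu_w, sig2_w).
import torch as th

th.manual_seed(0)
x = th.randn(16)
mu_w = th.randn(8, 16)
sig2_w = th.full((8, 16), -2.0).exp()           # stands in for logsig2_w.exp()
mu_a = mu_w @ x                                 # analytic mean of the pre-activation
var_a = sig2_w @ x.pow(2)                       # analytic variance
W = mu_w + sig2_w.sqrt() * th.randn(50000, 8, 16)
a = W @ x                                       # Monte-Carlo pre-activations, (50000, 8)
print(th.allclose(a.mean(0), mu_a, atol=0.1),
      th.allclose(a.var(0), var_a, rtol=0.1))   # True True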
class WeightNormalizedLinear(Module):
    def __init__(self,
                 in_features,
                 out_features,
                 scale=True,
                 bias=True,
                 init_factor=1,
                 init_scale=1):
        super(WeightNormalizedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.zeros(1, out_features))
        else:
            self.register_parameter('bias', None)
        if scale:
            self.scale = Parameter(
                torch.Tensor(1, out_features).fill_(init_scale))
        else:
            self.register_parameter('scale', None)
        self.reset_parameters(init_factor)

    def reset_parameters(self, factor):
        stdv = 1. * factor / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def weight_norm(self):
        return self.weight.pow(2).sum(1).add(1e-6).sqrt()

    def norm_scale_bias(self, input):
        # divide each output by the L2 norm of the corresponding weight row
        output = input.div(self.weight_norm().view(1, -1).expand_as(input))
        if self.scale is not None:
            output = output.mul(self.scale.expand_as(input))
        if self.bias is not None:
            output = output.add(self.bias.expand_as(input))
        return output

    def forward(self, input):
        return self.norm_scale_bias(F.linear(input, self.weight))

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
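# Quick check (illustrative) that the output of WeightNormalizedLinear is invariant to
# rescaling the raw weight, which is the point of dividing by weight_norm().
import torch

layer = WeightNormalizedLinear(5, 3)
x = torch.randn(2, 5)
y1 = layer(x)
layer.weight.data.mul_(10.0)                # rescale every weight row
y2 = layer(x)
print(torch.allclose(y1, y2, atol=1e-5))    # True: only the weight direction matters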
class BinaryGatedLinear(Module):
    """
    Linear layer with stochastic binary gates
    """
    def __init__(self,
                 in_features,
                 out_features,
                 l0_strength=1.,
                 l2_strength=1.,
                 learn_weight=True,
                 bias=True,
                 droprate_init=0.5,
                 random_weight=True,
                 deterministic=False,
                 use_baseline_bias=False,
                 optimize_inference=False,
                 one_sample_per_item=False,
                 **kwargs):
        """
        :param in_features: Input dimensionality
        :param out_features: Output dimensionality
        :param bias: Whether we use a bias
        :param l2_strength: Strength of the L2 penalty
        :param droprate_init: Dropout rate that the gates will be initialized to
        :param l0_strength: Strength of the L0 penalty
        """
        super(BinaryGatedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.l0_strength = l0_strength
        self.l2_strength = l2_strength
        self.deterministic = deterministic
        self.use_baseline_bias = use_baseline_bias
        self.optimize_inference = optimize_inference
        self.one_sample_per_item = one_sample_per_item

        self.random_weight = random_weight
        if random_weight:
            exc_weight = torch.Tensor(out_features, in_features)
            inh_weight = torch.Tensor(out_features, in_features)
        else:
            exc_weight = torch.ones(out_features, in_features)
            inh_weight = torch.ones(out_features, in_features)

        if learn_weight:
            self.exc_weight = Parameter(exc_weight)
            self.inh_weight = Parameter(inh_weight)
        else:
            self.register_buffer("exc_weight", exc_weight)
            self.register_buffer("inh_weight", inh_weight)

        self.exc_p1 = Parameter(torch.Tensor(out_features, in_features))
        self.inh_p1 = Parameter(torch.Tensor(out_features, in_features))

        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.use_bias = bias
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        if self.random_weight:
            init.kaiming_normal_(self.exc_weight, mode="fan_out")
            init.kaiming_normal_(self.inh_weight, mode="fan_out")
            self.exc_weight.data.abs_()
            self.inh_weight.data.abs_()
        self.exc_p1.data.normal_(1 - self.droprate_init, 1e-2)
        self.inh_p1.data.normal_(1 - self.droprate_init, 1e-2)
        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.exc_weight.data.clamp_(min=0.)
        self.inh_weight.data.clamp_(min=0.)

    def get_gate_probabilities(self):
        exc_p1 = torch.clamp(self.exc_p1.data, min=0., max=1.)
        inh_p1 = torch.clamp(self.inh_p1.data, min=0., max=1.)
        return exc_p1, inh_p1

    def weight_size(self):
        return self.exc_weight.size()

    def regularization(self):
        """
        Expected L0 norm under the stochastic gates, takes into account and
        re-weights also a potential L2 penalty
        """
        if self.l0_strength > 0 or self.l2_strength > 0:
            # Clamp these, but do it in a way that still always propagates the
            # gradient.
            exc_p1 = self.exc_p1.clone()
            torch.clamp(exc_p1.data, min=0, max=1, out=exc_p1.data)
            inh_p1 = self.inh_p1.clone()
            torch.clamp(inh_p1.data, min=0, max=1, out=inh_p1.data)

            if self.l2_strength == 0:
                return self.l0_strength * (exc_p1 + inh_p1).sum()
            else:
                exc_weight_decay_ungated = (.5 * self.l2_strength *
                                            self.exc_weight.pow(2))
                inh_weight_decay_ungated = (.5 * self.l2_strength *
                                            self.inh_weight.pow(2))
                exc_weight_l2_l0 = torch.sum(
                    (exc_weight_decay_ungated + self.l0_strength) * exc_p1)
                inh_weight_l2_l0 = torch.sum(
                    (inh_weight_decay_ungated + self.l0_strength) * inh_p1)
                bias_l2 = (0 if not self.use_bias else torch.sum(
                    .5 * self.l2_strength * self.bias.pow(2)))
                return exc_weight_l2_l0 + inh_weight_l2_l0 + bias_l2
        else:
            return 0

    def get_inference_mask(self):
        exc_p1, inh_p1 = self.get_gate_probabilities()

        if self.deterministic:
            exc_mask = (exc_p1 >= 0.5).float()
            inh_mask = (inh_p1 >= 0.5).float()
            return exc_mask, inh_mask
        else:
            exc_count1 = exc_p1.sum(dim=1).round().int()
            inh_count1 = inh_p1.sum(dim=1).round().int()

            # pytorch doesn't offer topk with varying k values.
            exc_mask = torch.zeros_like(exc_p1)
            inh_mask = torch.zeros_like(inh_p1)
            for i in range(exc_count1.size()[0]):
                _, exc_indices = torch.topk(exc_p1[i], exc_count1[i].item())
                _, inh_indices = torch.topk(inh_p1[i], inh_count1[i].item())
                exc_mask[i].scatter_(-1, exc_indices, 1)
                inh_mask[i].scatter_(-1, inh_indices, 1)

            return exc_mask, inh_mask

    def sample_weight_and_bias(self):
        if self.training or not self.optimize_inference:
            w = (sample_weight(self.exc_p1, self.exc_weight,
                               self.deterministic) -
                 sample_weight(self.inh_p1, self.inh_weight,
                               self.deterministic))
        else:
            exc_mask, inh_mask = self.get_inference_mask()
            w = exc_mask * self.exc_weight - inh_mask * self.inh_weight

        b = None
        if self.use_baseline_bias:
            b = -w.sum(dim=-1) / 2

        if self.use_bias:
            b = (b + self.bias if b is not None else self.bias)

        return w, b

    def forward(self, x):
        if self.one_sample_per_item and self.training and len(x.size()) > 1:
            results = []
            for i in range(x.size(0)):
                w, b = self.sample_weight_and_bias()
                results.append(F.linear(x[i:i + 1], w, b))
            return torch.cat(results)
        else:
            w, b = self.sample_weight_and_bias()
            return F.linear(x, w, b)

    def get_expected_nonzeros(self):
        exc_p1, inh_p1 = self.get_gate_probabilities()

        # Flip two coins with probabilities pi_1 and pi_2. What is the
        # probability one of them is 1?
        #
        # 1 - (1 - pi_1)*(1 - pi_2)
        # = 1 - 1 + pi_1 + pi_2 - pi_1*pi_2
        # = pi_1 + pi_2 - pi_1*pi_2
        p1 = exc_p1 + inh_p1 - (exc_p1 * inh_p1)

        return p1.sum(dim=1).detach()

    def get_inference_nonzeros(self):
        exc_mask, inh_mask = self.get_inference_mask()

        return torch.sum(exc_mask.int() | inh_mask.int(), dim=1)

    def count_inference_flops(self):
        # For each unit, multiply with its n inputs then do n - 1 additions.
        # To capture the -1, subtract it, but only in cases where there is at
        # least one weight.
        nz_by_unit = self.get_inference_nonzeros()
        multiplies = torch.sum(nz_by_unit)
        adds = multiplies - torch.sum(nz_by_unit > 0)
        return multiplies.item(), adds.item()
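# Hypothetical objective (not from the original code): the expected-L0/L2 penalty
# returned by regularization() is simply added to the task loss.
import torch.nn.functional as F

def gated_loss(model, logits, targets, reg_scale=1.0):
    loss = F.cross_entropy(logits, targets)
    for m in model.modules():
        if isinstance(m, BinaryGatedLinear):
            loss = loss + reg_scale * m.regularization()
    return loss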
class CLTLinear(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 prior_prec=10,
                 relu_act=True,
                 elu_act=False):
        super(CLTLinear, self).__init__()
        self.n_in = in_features
        self.n_out = out_features

        self.prior_prec = prior_prec

        assert not (
            relu_act and elu_act
        )  # A single layer can only do either relu or elu activation
        self.relu_act = relu_act
        self.elu_act = elu_act

        self.bias = nn.Parameter(th.Tensor(out_features))
        self.mu_w = Parameter(th.Tensor(out_features, in_features))
        self.logsig2_w = nn.Parameter(th.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        # TODO: Adapt to the newest pytorch initializations
        stdv = 1. / math.sqrt(self.mu_w.size(1))
        self.mu_w.data.normal_(0, stdv)
        self.logsig2_w.data.zero_().normal_(-9, 0.001)
        self.bias.data.zero_()

    def KL(self, loguniform=False):
        if loguniform:
            k1 = 0.63576
            k2 = 1.87320
            k3 = 1.48695
            log_alpha = self.logsig2_w - 2 * th.log(self.mu_w.abs() + 1e-8)
            kl = -th.sum(k1 * th.sigmoid(k2 + k3 * log_alpha) -
                         0.5 * F.softplus(-log_alpha) - k1)
        else:
            logsig2_w = self.logsig2_w.clamp(-11, 11)
            kl = 0.5 * (self.prior_prec *
                        (self.mu_w.pow(2) + logsig2_w.exp()) - logsig2_w - 1 -
                        np.log(self.prior_prec)).sum()
        return kl

    def cdf(self, x, mu=0., sig=1.):
        return 0.5 * (1 + th.erf((x - mu) / (sig * math.sqrt(2))))

    def pdf(self, x, mu=0., sig=1.):
        return (1 / (math.sqrt(2 * math.pi) * sig)) * th.exp(-0.5 * (
            (x - mu) / sig).pow(2))

    def relu_moments(self, mu, sig):
        alpha = mu / sig
        cdf = self.cdf(alpha)
        pdf = self.pdf(alpha)
        relu_mean = mu * cdf + sig * pdf
        relu_var = (sig.pow(2) +
                    mu.pow(2)) * cdf + mu * sig * pdf - relu_mean.pow(2)
        relu_var.clamp_(1e-8)  # Avoid negative variance due to numerics
        return relu_mean, relu_var

    def elu_moments_orig(self, mu, sig):
        # the original method without simplifications
        sig2 = sig.pow(2)
        elu_mean = th.exp(mu.clamp_max(10) + sig2 / 2) * self.cdf(
            -(mu + sig2) / sig) - self.cdf(-mu / sig)
        elu_mean += mu * self.cdf(mu / sig) + sig * self.pdf(mu / sig)
        elu_var = th.exp(2 * mu.clamp_max(10) + 2 * sig2) * self.cdf(
            -(mu + 2 * sig2) / sig)
        elu_var += -2 * th.exp(mu.clamp_max(10) + sig2 / 2) * self.cdf(
            -(mu + sig2) / sig)
        elu_var += self.cdf(-mu / sig)
        elu_var += (sig2 + mu.pow(2)) * self.cdf(
            mu / sig) + mu * sig * self.pdf(mu / sig)
        elu_var += -elu_mean.pow(2)
        elu_var.clamp_min_(1e-8)  # Avoid negative variance due to numerics
        return elu_mean, elu_var

    def elu_moments(self, mu, sig):
        # NOTE: For now it is without alpha or the selu extension!
        # Note: Takes roughly 3x as much time as the relu
        # Clamp the mus to avoid problems in the expectation
        sig2 = sig.pow(2)
        alpha = mu / sig

        cdf_alpha = self.cdf(alpha)
        pdf_alpha = self.pdf(alpha)
        cdf_malpha = 1 - cdf_alpha
        cdf_malphamsig = self.cdf(-alpha - sig)

        elu_mean = th.exp(mu.clamp_max(10) +
                          sig2 / 2) * cdf_malphamsig - cdf_malpha
        elu_mean += mu * cdf_alpha + sig * pdf_alpha

        elu_var = th.exp(2 * mu.clamp_max(10) + 2 * sig2) * self.cdf(-alpha -
                                                                     2 * sig)
        elu_var += -2 * th.exp(mu.clamp_max(10) + sig2 / 2) * cdf_malphamsig
        elu_var += cdf_malpha
        elu_var += (sig2 + mu.pow(2)) * cdf_alpha + mu * sig * pdf_alpha
        elu_var += -elu_mean.pow(2)
        elu_var.clamp_min_(1e-8)  # Avoid negative variance due to numerics
        return elu_mean, elu_var

    def forward(self, mu_inp, var_inp=None):
        s2_w = self.logsig2_w.exp()

        mu_out = F.linear(mu_inp, self.mu_w, self.bias)
        if var_inp is None:
            var_out = F.linear(mu_inp.pow(2), s2_w) + 1e-8
        else:
            var_out = F.linear(var_inp + mu_inp.pow(2), s2_w) + F.linear(
                var_inp, self.mu_w.pow(2)) + 1e-8

        if self.relu_act:
            mu_out, var_out = self.relu_moments(mu_out, var_out.sqrt())

        if self.elu_act:
            mu_out, var_out = self.elu_moments(mu_out, var_out.sqrt())

        return mu_out, var_out  # + 1e-8 Already provided in the moment computation

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.n_in) + ' -> ' \
               + str(self.n_out) \
               + f", activation={self.relu_act or self.elu_act}" \
               + f" ({'relu' if self.relu_act else ('elu' if self.elu_act else '')}))"
class BinaryGatedConv2d(Module):
    """
    Convolutional layer with binary stochastic gates
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 learn_weight=True,
                 bias=True,
                 droprate_init=0.5,
                 l2_strength=1.,
                 l0_strength=1.,
                 random_weight=True,
                 deterministic=False,
                 use_baseline_bias=False,
                 optimize_inference=True,
                 one_sample_per_item=False,
                 **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Size of the kernel
        :param stride: Stride for the convolution
        :param padding: Padding for the convolution
        :param dilation: Dilation factor for the convolution
        :param groups: How many groups we will assume in the convolution
        :param bias: Whether we will use a bias
        :param droprate_init: Dropout rate that the gates will be initialized to
        :param l2_strength: Strength of the L2 penalty
        :param l0_strength: Strength of the L0 penalty
        """
        super(BinaryGatedConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.l2_strength = l2_strength
        self.l0_strength = l0_strength
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.deterministic = deterministic
        self.use_baseline_bias = use_baseline_bias
        self.optimize_inference = optimize_inference
        self.one_sample_per_item = one_sample_per_item

        self.random_weight = random_weight
        if random_weight:
            exc_weight = torch.Tensor(out_channels, in_channels // groups,
                                      *self.kernel_size)
            inh_weight = torch.Tensor(out_channels, in_channels // groups,
                                      *self.kernel_size)
        else:
            exc_weight = torch.ones(out_channels, in_channels // groups,
                                    *self.kernel_size)
            inh_weight = torch.ones(out_channels, in_channels // groups,
                                    *self.kernel_size)

        if learn_weight:
            self.exc_weight = Parameter(exc_weight)
            self.inh_weight = Parameter(inh_weight)
        else:
            self.register_buffer("exc_weight", exc_weight)
            self.register_buffer("inh_weight", inh_weight)
        self.exc_p1 = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.inh_p1 = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.dim_z = out_channels
        self.input_shape = None

        self.use_bias = bias
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        if self.random_weight:
            init.kaiming_normal_(self.exc_weight, mode="fan_out")
            init.kaiming_normal_(self.inh_weight, mode="fan_out")
            self.exc_weight.data.abs_()
            self.inh_weight.data.abs_()
        self.exc_p1.data.normal_(1 - self.droprate_init, 1e-2)
        self.inh_p1.data.normal_(1 - self.droprate_init, 1e-2)

        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.exc_weight.data.clamp_(min=0.)
        self.inh_weight.data.clamp_(min=0.)

    def weight_size(self):
        return self.exc_weight.size()

    def regularization(self):
        """
        Expected L0 norm under the stochastic gates, takes into account and
        re-weights also a potential L2 penalty
        """

        if self.l0_strength > 0 or self.l2_strength > 0:
            # Clamp these, but do it in a way that still always propagates the
            # gradient.
            exc_p1 = self.exc_p1.clone()
            torch.clamp(exc_p1.data, min=0, max=1, out=exc_p1.data)
            inh_p1 = self.inh_p1.clone()
            torch.clamp(inh_p1.data, min=0, max=1, out=inh_p1.data)

            if self.l2_strength == 0:
                return self.l0_strength * (exc_p1 + inh_p1).sum()
            else:
                exc_weight_decay_ungated = (.5 * self.l2_strength *
                                            self.exc_weight.pow(2))
                inh_weight_decay_ungated = (.5 * self.l2_strength *
                                            self.inh_weight.pow(2))
                exc_weight_l2_l0 = torch.sum(
                    (exc_weight_decay_ungated + self.l0_strength) * exc_p1)
                inh_weight_l2_l0 = torch.sum(
                    (inh_weight_decay_ungated + self.l0_strength) * inh_p1)
                bias_l2 = (0 if not self.use_bias else torch.sum(
                    .5 * self.l2_strength * self.bias.pow(2)))
                return exc_weight_l2_l0 + inh_weight_l2_l0 + bias_l2
        else:
            return 0

    def get_gate_probabilities(self):
        exc_p1 = torch.clamp(self.exc_p1.data, min=0., max=1.)
        inh_p1 = torch.clamp(self.inh_p1.data, min=0., max=1.)
        return exc_p1, inh_p1

    def get_inference_mask(self):
        exc_p1, inh_p1 = self.get_gate_probabilities()

        if self.deterministic:
            exc_mask = (exc_p1 >= 0.5).float()
            inh_mask = (inh_p1 >= 0.5).float()
            return exc_mask, inh_mask
        else:
            exc_count1 = exc_p1.sum(
                dim=tuple(range(1, len(exc_p1.shape)))).round().int()
            inh_count1 = inh_p1.sum(
                dim=tuple(range(1, len(inh_p1.shape)))).round().int()

            # pytorch doesn't offer topk with varying k values.
            exc_mask = torch.zeros_like(exc_p1)
            inh_mask = torch.zeros_like(inh_p1)
            for i in range(exc_count1.size()[0]):
                _, exc_indices = torch.topk(exc_p1[i].flatten(),
                                            exc_count1[i].item())
                _, inh_indices = torch.topk(inh_p1[i].flatten(),
                                            inh_count1[i].item())
                exc_mask[i].flatten().scatter_(-1, exc_indices, 1)
                inh_mask[i].flatten().scatter_(-1, inh_indices, 1)

            return exc_mask, inh_mask

    def sample_weight_and_bias(self, samples=1):
        if self.training or not self.optimize_inference:
            w = (sample_weight(self.exc_p1, self.exc_weight,
                               self.deterministic, samples) -
                 sample_weight(self.inh_p1, self.inh_weight,
                               self.deterministic, samples))
        else:
            exc_mask, inh_mask = self.get_inference_mask()
            w = exc_mask * self.exc_weight - inh_mask * self.inh_weight

        b = None
        if self.use_baseline_bias:
            b = -w.sum(dim=(-3, -2, -1)) / 2

        if self.use_bias:
            b = (b + self.bias if b is not None else self.bias)

        return w, b

    def forward(self, x):
        if self.input_shape is None:
            self.input_shape = x.size()

        if self.one_sample_per_item and self.training and len(x.size()) > 3:
            w, b = self.sample_weight_and_bias(x.size(0))

            if self.use_baseline_bias:
                b = b.view(x.size(0) * self.out_channels)
            elif b is not None:
                b = b.repeat(x.size(0))

            x_ = x.view(1, x.size(0) * x.size(1), *x.size()[2:])
            w_ = w.view(w.size(0) * w.size(1), *w.size()[2:])
            result = F.conv2d(x_, w_, b, self.stride, self.padding,
                              self.dilation,
                              x.size(0) * self.groups)

            return result.view(x.size(0), self.out_channels,
                               *result.size()[2:])
        else:
            w, b = self.sample_weight_and_bias()
            return F.conv2d(x, w, b, self.stride, self.padding, self.dilation,
                            self.groups)

    def get_expected_nonzeros(self):
        exc_p1, inh_p1 = self.get_gate_probabilities()

        # Flip two coins with probabilities pi_1 and pi_2. What is the
        # probability one of them is 1?
        #
        # 1 - (1 - pi_1)*(1 - pi_2)
        # = 1 - 1 + pi_1 + pi_2 - pi_1*pi_2
        # = pi_1 + pi_2 - pi_1*pi_2
        p1 = exc_p1 + inh_p1 - (exc_p1 * inh_p1)

        return p1.sum(dim=tuple(range(1, len(p1.shape)))).detach()

    def get_inference_nonzeros(self):
        exc_mask, inh_mask = self.get_inference_mask()
        return torch.sum(exc_mask.int() | inh_mask.int(),
                         dim=tuple(range(1, len(exc_mask.shape))))

    def count_inference_flops(self):
        # For each unit, multiply with n inputs then do n - 1 additions.
        # Only subtract 1 in cases where there is at least one weight.
        nz_by_unit = self.get_inference_nonzeros()
        multiplies_per_instance = torch.sum(nz_by_unit)
        adds_per_instance = multiplies_per_instance - torch.sum(nz_by_unit > 0)

        # for rows
        instances = ((self.input_shape[-2] - self.kernel_size[0] +
                      2 * self.padding[0]) / self.stride[0]) + 1
        # multiplying with cols
        instances *= ((self.input_shape[-1] - self.kernel_size[1] +
                       2 * self.padding[1]) / self.stride[1]) + 1

        multiplies = multiplies_per_instance * instances
        adds = adds_per_instance * instances

        return multiplies.item(), adds.item()
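# One-line check of the "two coins" identity used in get_expected_nonzeros():
# P(exc on OR inh on) = 1 - (1 - p_exc)(1 - p_inh) = p_exc + p_inh - p_exc * p_inh.
p_exc, p_inh = 0.2, 0.5
assert abs((1 - (1 - p_exc) * (1 - p_inh)) - (p_exc + p_inh - p_exc * p_inh)) < 1e-12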
class _ConvNdGroupNJ(Module):
    """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding,
                 groups, bias, init_weight, init_bias, cuda=False, clip_var=None):
        super(_ConvNdGroupNJ, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        self.cuda = cuda
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference

        if transposed:
            self.weight_mu = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight_mu = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))

        self.bias_mu = Parameter(torch.Tensor(out_channels))
        self.bias_logvar = Parameter(torch.Tensor(out_channels))

        self.z_mu = Parameter(torch.Tensor(self.out_channels))
        self.z_logvar = Parameter(torch.Tensor(self.out_channels))

        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()
        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)

        # init means
        if init_weight is not None:
            self.weight_mu.data = init_weight
        else:
            self.weight_mu.data.uniform_(-stdv, stdv)

        if init_bias is not None:
            self.bias_mu.data = init_bias
        else:
            self.bias_mu.data.fill_(0)

        # init z
        self.z_mu.data.normal_(1, 1e-2)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = - 0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = - 0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias_mu is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
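# Hypothetical pruning rule built on get_log_dropout_rates(): output channels whose
# log dropout rate exceeds a threshold are removed. The threshold of 3 is the common
# choice in the variational-dropout literature ([2]), not something taken from this file.
def prune_mask(layer, threshold=3.0):
    log_alpha = layer.get_log_dropout_rates()
    return log_alpha < threshold        # True = keep this output channel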
class group_relaxed_TF1Conv2d(Module):
	"""Implementation of TF1 regularization for the feature maps of a convolutional layer"""
	def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
		lamba=1., alpha=1., beta=4., weight_decay = 1., **kwargs):
		"""
		:param in_channels: Number of input channels
		:param out_channels: Number of output channels
		:param kernel_size: size of the kernel
		:param stride: stride for the convolution
		:param padding: padding for the convolution
		:param dilation: dilation factor for the convolution
		:param groups: how many groups we will assume in the convolution
		:param bias: whether we will use a bias
		:param lamba: strength of the TF1 regularization
		"""
		super(group_relaxed_TF1Conv2d, self).__init__()
		self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
		self.in_channels = in_channels
		self.out_channels = out_channels
		self.kernel_size = pair(kernel_size)
		self.stride = pair(stride)
		self.padding = pair(padding)
		self.dilation = pair(dilation)
		self.output_padding = pair(0)
		self.groups = groups
		self.lamba = lamba
		self.alpha = alpha
		self.beta = beta
		self.lamba1 = self.lamba/self.beta
		self.weight_decay = weight_decay
		self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
		self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size)
		self.u = self.u.to('cuda' if torch.cuda.is_available() else 'cpu')
		if bias:
			self.bias = Parameter(torch.Tensor(out_channels))
		else:
			self.register_parameter('bias', None)
		self.reset_parameters()
		self.input_shape = None
		print(self)

	def reset_parameters(self):
		init.kaiming_normal_(self.weight, mode='fan_in')
		

		if self.bias is not None:
			self.bias.data.normal_(0,1e-2)

	def phi(self,x):
		phi_x = torch.acos(1-27*(self.lamba1*self.alpha*(self.alpha+1))/(2*(self.alpha+x.abs())**3))
		return phi_x

	def g(self,x):
		g_x = x.sign()*(2/3*(self.alpha + x.abs())*torch.cos(self.phi(x)/3)-2*self.alpha/3+x.abs()/3)
		return g_x


	def constrain_parameters(self, thres_std=1.):
		#self.weight.data = F.normalize(self.weight.data, p=2, dim=1)
		#print(torch.sum(self.weight.pow(2)))
		if self.lamba1 <= (self.alpha**2)/(2*(self.alpha+1)):
			t = self.lamba1*(self.alpha+1)/(self.alpha)
		else:
			t = np.sqrt(2*self.lamba1*(self.alpha+1))-self.alpha/2
		self.u.data = self.weight.data.clone()
		self.u.data[self.u.data.abs() <=t] = 0
		g_result = self.g(self.u)
		self.u.data[self.u.data.abs() > t] = g_result[self.u.data.abs() > t]

	def grow_beta(self, growth_factor):
		self.beta = self.beta*growth_factor
		self.lamba1 = self.lamba/self.beta

	def _reg_w(self, **kwargs):
		logpw = -self.beta*torch.sum(0.5*self.weight.add(-self.u).pow(2))-self.lamba*np.sqrt(self.in_channels*self.kernel_size[0]*self.kernel_size[1])*torch.sum(torch.pow(torch.sum(self.weight.pow(2),3).sum(2).sum(1),0.5))
		logpb = 0
		if self.bias is not None:
			logpb = - torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
		return logpw+logpb

	def regularization(self):
		return self._reg_w()

	def count_zero_u(self):
		total = np.prod(self.u.size())
		zero = total - self.u.nonzero().size(0)
		return zero

	def count_zero_w(self):
		return torch.sum((self.weight.abs()<1e-5).int()).item()

	def count_active_neuron(self):
		return torch.sum((torch.sum(self.weight.abs(),3).sum(2).sum(1)/(self.in_channels*self.kernel_size[0]*self.kernel_size[1]))>1e-5).item()

	def count_total_neuron(self):
		return self.out_channels


	def count_weight(self):
		return np.prod(self.u.size())

	def count_expected_flops_and_l0(self):
		#ppos = self.out_channels
		ppos = torch.sum(torch.sum(self.weight.abs(),3).sum(2).sum(1)>0.001).item()
		n = self.kernel_size[0]*self.kernel_size[1]*self.in_channels
		flops_per_instance = n+(n-1)

		num_instances_per_filter = ((self.input_shape[1] -self.kernel_size[0]+2*self.padding[0])/self.stride[0]) + 1
		num_instances_per_filter *=((self.input_shape[2] - self.kernel_size[1]+2*self.padding[1])/self.stride[1]) + 1

		flops_per_filter = num_instances_per_filter * flops_per_instance
		expected_flops = flops_per_filter*ppos
		expected_l0 = n*ppos

		if self.bias is not None:
			expected_flops += num_instances_per_filter*ppos
			expected_l0 += ppos
		return expected_flops, expected_l0

	def forward(self, input_):
		if self.input_shape is None:
			self.input_shape = input_.size()
		output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
		return output

	def __repr__(self):
		s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} '
			', stride={stride}')
		if self.padding != (0,) * len(self.padding):
			s += ', padding={padding}'
		if self.dilation != (1,) * len(self.dilation):
			s += ', dilation={dilation}'
		if self.output_padding != (0,) * len(self.output_padding):
			s += ', output_padding={output_padding}'
		if self.groups != 1:
			s += ', groups={groups}'
		if self.bias is None:
			s += ', bias=False'
		s += ')'
		return s.format(name=self.__class__.__name__, **self.__dict__)
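# Assumed training step (schedule and sign convention are not from this file): the
# layer log-priors from regularization() are subtracted from the task loss (they are
# log-probabilities, hence negative), and constrain_parameters() then refreshes the
# auxiliary variable u via the TF1 proximal map. grow_beta() would typically be called
# on a slower schedule, e.g. once per epoch.
def relaxed_tf1_step(model, optimizer, criterion, x, y):
    layers = [m for m in model.modules() if isinstance(m, group_relaxed_TF1Conv2d)]
    optimizer.zero_grad()
    loss = criterion(model(x), y) - sum(m.regularization() for m in layers)
    loss.backward()
    optimizer.step()
    for m in layers:
        m.constrain_parameters()
    return loss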
class HardConcreteGatedConv2d(Module):
    """
    Convolutional layer with stochastic connections, as in
    https://arxiv.org/abs/1712.01312
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 learn_weight=True,
                 droprate_init=0.5,
                 temperature=(2 / 3),
                 l2_strength=1.,
                 l0_strength=1.,
                 **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Size of the kernel
        :param stride: Stride for the convolution
        :param padding: Padding for the convolution
        :param dilation: Dilation factor for the convolution
        :param groups: How many groups we will assume in the convolution
        :param bias: Whether we will use a bias
        :param droprate_init: Dropout rate that the L0 gates will be initialized
                              to
        :param temperature: Temperature of the concrete distribution
        :param l2_strength: Strength of the L2 penalty
        :param l0_strength: Strength of the L0 penalty
        """
        super(HardConcreteGatedConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.l2_strength = l2_strength
        self.l0_strength = l0_strength
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.temperature = temperature
        self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available()
                            else torch.cuda.FloatTensor)
        self.use_bias = False
        weight = torch.Tensor(out_channels, in_channels // groups,
                              *self.kernel_size)
        if learn_weight:
            self.weight = Parameter(weight)
        else:
            self.register_buffer("weight", weight)
        self.loga = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.dim_z = out_channels
        self.input_shape = None

        if bias:
            bias = torch.Tensor(out_channels)
            if learn_weight:
                self.bias = Parameter(bias)
            else:
                self.register_buffer("bias", bias)
            self.use_bias = True

        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode="fan_in")

        self.loga.data.normal_(
            math.log(1 - self.droprate_init) - math.log(self.droprate_init),
            1e-2)

        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

    def cdf_qz(self, x):
        """Implements the CDF of the 'stretched' concrete distribution"""
        xn = (x - LIMIT_A) / (LIMIT_B - LIMIT_A)
        logits = math.log(xn) - math.log(1 - xn)
        return torch.sigmoid(logits * self.temperature - self.loga).clamp(
            min=EPSILON, max=1 - EPSILON)

    def quantile_concrete(self, x):
        """
        Implements the quantile, aka inverse CDF, of the 'stretched' concrete
        distribution
        """
        y = torch.sigmoid(
            (torch.log(x) - torch.log(1 - x) + self.loga) / self.temperature)
        return y * (LIMIT_B - LIMIT_A) + LIMIT_A

    def weight_size(self):
        return self.weight.size()

    def regularization(self):
        """
        Expected L0 norm under the stochastic gates, takes into account and
        re-weights also a potential L2 penalty
        """
        weight_decay_ungated = .5 * self.l2_strength * self.weight.pow(2)
        weight_l2_l0 = torch.sum(
            (weight_decay_ungated + self.l0_strength) * (1 - self.cdf_qz(0)))
        bias_l2 = (0 if not self.use_bias else torch.sum(.5 *
                                                         self.l2_strength *
                                                         self.bias.pow(2)))
        return weight_l2_l0 + bias_l2

    def count_inference_flops(self):
        # For each unit, multiply with n inputs then do n - 1 additions.
        # Only subtract 1 in cases where there is at least one weight.
        nz_by_unit = self.get_inference_nonzeros()
        multiplies_per_instance = torch.sum(nz_by_unit)
        adds_per_instance = multiplies_per_instance - torch.sum(nz_by_unit > 0)

        # for rows
        instances = ((self.input_shape[-2] - self.kernel_size[0] +
                      2 * self.padding[0]) / self.stride[0]) + 1
        # multiplying with cols
        instances *= ((self.input_shape[-1] - self.kernel_size[1] +
                       2 * self.padding[1]) / self.stride[1]) + 1

        multiplies = multiplies_per_instance * instances
        adds = adds_per_instance * instances

        return multiplies.item(), adds.item()

    def count_expected_flops_and_l0(self):
        """
        Measures the expected floating point operations (FLOPs) and the expected
        L0 norm

        Copied from the original L0 paper code
        """
        ppos = torch.sum(1 - self.cdf_qz(0))
        # vector_length
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        # (n: multiplications and n-1: additions)
        flops_per_instance = n + (n - 1)

        # for rows
        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) /
            self.stride[0]) + 1
        # multiplying with cols
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) /
            self.stride[1]) + 1

        flops_per_filter = num_instances_per_filter * flops_per_instance
        # multiply with number of filters
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos

        if self.use_bias:
            # since the gate is applied to the output we also reduce the bias
            # computation
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos

        return expected_flops.item(), expected_l0.item()

    def get_eps(self, size):
        """Uniform random numbers for the concrete distribution"""
        eps = self.floatTensor(size).uniform_(EPSILON, 1 - EPSILON)
        eps = Variable(eps)
        return eps

    def sample_weight(self):
        if self.training:
            z = self.quantile_concrete(
                self.get_eps(self.floatTensor(self.loga.size())))
            mask = F.hardtanh(z, min_val=0, max_val=1)
        else:
            pi = torch.sigmoid(self.loga)
            mask = F.hardtanh(pi * (LIMIT_B - LIMIT_A) + LIMIT_A,
                              min_val=0,
                              max_val=1)

        return mask * self.weight

    def forward(self, x):
        if self.input_shape is None:
            self.input_shape = x.size()
        return F.conv2d(x, self.sample_weight(),
                        (self.bias if self.use_bias else None), self.stride,
                        self.padding, self.dilation, self.groups)

    def get_expected_nonzeros(self):
        expected_gates = 1 - self.cdf_qz(0)
        return expected_gates.sum(
            dim=tuple(range(1, len(expected_gates.shape)))).detach()

    def get_inference_nonzeros(self):
        inference_gates = F.hardtanh(torch.sigmoid(self.loga) *
                                     (LIMIT_B - LIMIT_A) + LIMIT_A,
                                     min_val=0,
                                     max_val=1)
        return (inference_gates > 0).sum(
            dim=tuple(range(1, len(inference_gates.shape)))).detach()

    def __repr__(self):
        s = (
            "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}, "
            "stride={stride}, droprate_init={droprate_init}, "
            "temperature={temperature}, l2_strength={l2_strength}, "
            "l0_strength={l0_strength}")
        if self.padding != (0, ) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1, ) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ", output_padding={output_padding}"
        if self.groups != 1:
            s += ", groups={groups}"
        if not self.use_bias:
            s += ", bias=False"
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)
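# Standalone sketch of the hard-concrete gate sample behind sample_weight().
# LIMIT_A, LIMIT_B and EPSILON are module-level constants not shown in this excerpt;
# -0.1, 1.1 and 1e-6 are the values used in the L0 paper and assumed here.
import torch

LIMIT_A, LIMIT_B, EPSILON = -0.1, 1.1, 1e-6

def hard_concrete_sample(loga, temperature=2 / 3):
    u = torch.zeros_like(loga).uniform_(EPSILON, 1 - EPSILON)
    s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + loga) / temperature)
    # stretch to (LIMIT_A, LIMIT_B), then clip: many gates end up exactly 0 or 1
    return torch.clamp(s * (LIMIT_B - LIMIT_A) + LIMIT_A, 0, 1)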
class BinaryGatedConv2d(Module):
    """
    Convolutional layer with binary stochastic gates
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 learn_weight=True,
                 bias=True,
                 droprate_init=0.5,
                 l2_strength=1.,
                 l0_strength=1.,
                 **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Size of the kernel
        :param stride: Stride for the convolution
        :param padding: Padding for the convolution
        :param dilation: Dilation factor for the convolution
        :param groups: How many groups we will assume in the convolution
        :param bias: Whether we will use a bias
        :param droprate_init: Dropout rate that the gates will be initialized to
        :param l2_strength: Strength of the L2 penalty
        :param l0_strength: Strength of the L0 penalty
        """
        super(BinaryGatedConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.l2_strength = l2_strength
        self.l0_strength = l0_strength
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available()
                            else torch.cuda.FloatTensor)
        self.use_bias = False
        weight = torch.Tensor(out_channels, in_channels // groups,
                              *self.kernel_size)
        if learn_weight:
            self.weight = Parameter(weight)
        else:
            self.register_buffer("weight", weight)
        self.logit_p1 = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.dim_z = out_channels
        self.input_shape = None

        if bias:
            b = torch.Tensor(out_channels)
            if learn_weight:
                self.bias = Parameter(b)
            else:
                self.register_buffer("bias", b)
            self.use_bias = True

        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode="fan_in")
        self.logit_p1.data.normal_(
            math.log(1 - self.droprate_init) - math.log(self.droprate_init),
            1e-2)

        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        pass

    def regularization(self):
        """
        Expected L0 norm under the stochastic gates, takes into account and
        re-weights also a potential L2 penalty
        """
        p1 = torch.sigmoid(self.logit_p1)
        weight_decay_ungated = .5 * self.l2_strength * self.weight.pow(2)
        weight_l2_l0 = torch.sum(
            (weight_decay_ungated + self.l0_strength) * p1)
        bias_l2 = (0 if not self.use_bias else torch.sum(.5 *
                                                         self.l2_strength *
                                                         self.bias.pow(2)))
        return -weight_l2_l0 - bias_l2

    def sample_weight(self):
        u = self.floatTensor(self.logit_p1.size()).uniform_(0, 1)

        p1 = torch.sigmoid(self.logit_p1)
        mask = p1 > u

        def cc_to_p1(grad):
            ratio = p1 / (1 - p1)
            p1.backward(grad * torch.where(mask, 1 / ratio, ratio))
            return grad

        z = mask.float()
        z.requires_grad_()
        z.register_hook(cc_to_p1)

        return self.weight * z

    def forward(self, x):
        return F.conv2d(x, self.sample_weight(),
                        (self.bias if self.use_bias else None), self.stride,
                        self.padding, self.dilation, self.groups)

    def get_expected_nonzeros(self):
        expected_gates = torch.sigmoid(self.logit_p1)
        return expected_gates.sum(
            dim=tuple(range(1, len(expected_gates.shape)))).detach()

    def get_inference_nonzeros(self):
        u = self.floatTensor(self.logit_p1.size()).uniform_(0, 1)
        inference_gates = torch.sigmoid(self.logit_p1) > u
        return inference_gates.sum(
            dim=tuple(range(1, len(inference_gates.shape)))).detach()
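
# --- Usage sketch (not part of the original listing) ---
# A minimal, standalone illustration of the gate bookkeeping BinaryGatedConv2d
# relies on: keep-probabilities come from a sigmoid over the gate logits, the
# expected L0 norm is simply the sum of those probabilities, and a stochastic
# forward pass thresholds them against uniform noise. Shapes are hypothetical.
import torch

torch.manual_seed(0)
logit_p1 = torch.randn(4, 3, 3, 3)       # gate logits (out, in, kH, kW)
p1 = torch.sigmoid(logit_p1)             # per-weight keep probabilities
expected_l0 = p1.sum()                   # expected number of surviving weights
u = torch.rand_like(p1)
sampled_mask = (p1 > u).float()          # one stochastic realisation of the gates
print(expected_l0.item(), sampled_mask.sum().item())
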
Beispiel #14
0
class group_relaxed_L0Dense(Module):
    """Implementation of TFL regularization for the input units of a fully connected layer"""
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 lamba=1.,
                 beta=4.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the TF1 regularization
		"""
        super(group_relaxed_L0Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.u = torch.rand(in_features, out_features)
        if torch.cuda.is_available():
            self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.beta = beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        #self.weight.data = F.normalize(self.weight.data, p=2, dim=1)
        m = Hardshrink((2 * self.lamba / self.beta)**(1 / 2))
        self.u.data = m(self.weight.data)

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(
            0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(
                self.out_features) * torch.sum(
                    torch.pow(torch.sum(self.weight.pow(2), 1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(
            torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__+' (' \
         + str(self.in_features) + ' -> ' \
         + str(self.out_features) + ', lambda: ' \
         + str(self.lamba) + ')'
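
# --- Usage sketch (not part of the original listing) ---
# constrain_parameters() above hard-thresholds the auxiliary variable: u keeps
# a weight only when |w| exceeds sqrt(2*lamba/beta), which is the closed-form
# minimiser of beta/2*(w - u)^2 + lamba*1[u != 0]. A standalone check with
# torch.nn.Hardshrink and hand-picked values:
import torch
import torch.nn as nn

lamba, beta = 1.0, 4.0
threshold = (2 * lamba / beta) ** 0.5        # sqrt(2*lamba/beta) ~= 0.707
w = torch.tensor([-1.5, -0.5, 0.1, 0.8, 2.0])
u = nn.Hardshrink(threshold)(w)              # zeroes entries with |w| <= threshold
print(u)                                     # tensor([-1.5000, 0.0000, 0.0000, 0.8000, 2.0000])
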
Beispiel #15
0
class L0Dense(Module):
    """Implementation of L0 regularization for the input units of a fully connected layer"""
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 weight_decay=1.,
                 droprate_init=0.5,
                 temperature=2. / 3.,
                 lamba=1.,
                 local_rep=False,
                 **kwargs):
        """
        :param in_features: Input dimensionality
        :param out_features: Output dimensionality
        :param bias: Whether we use a bias
        :param weight_decay: Strength of the L2 penalty
        :param droprate_init: Dropout rate that the L0 gates will be initialized to
        :param temperature: Temperature of the concrete distribution
        :param lamba: Strength of the L0 penalty
        :param local_rep: Whether we will use a separate gate sample per element in the minibatch
        """
        super(L0Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_prec = weight_decay
        self.weights = Parameter(torch.Tensor(in_features, out_features))
        self.qz_loga = Parameter(torch.Tensor(in_features))
        self.temperature = temperature
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.lamba = lamba
        self.use_bias = False
        self.local_rep = local_rep
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weights, mode='fan_out')

        self.qz_loga.data.normal_(
            math.log(1 - self.droprate_init) - math.log(self.droprate_init),
            1e-2)

        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.qz_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

    def cdf_qz(self, x):
        """Implements the CDF of the 'stretched' concrete distribution"""
        xn = (x - limit_a) / (limit_b - limit_a)
        logits = math.log(xn) - math.log(1 - xn)
        return torch.sigmoid(logits * self.temperature - self.qz_loga).clamp(
            min=epsilon, max=1 - epsilon)

    def quantile_concrete(self, x):
        """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
        y = torch.sigmoid((torch.log(x) - torch.log(1 - x) + self.qz_loga) /
                          self.temperature)
        return y * (limit_b - limit_a) + limit_a

    def _reg_w(self):
        """Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty"""
        logpw_col = torch.sum(
            -(.5 * self.prior_prec * self.weights.pow(2)) - self.lamba, 1)
        logpw = torch.sum((1 - self.cdf_qz(0)) * logpw_col)
        logpb = 0 if not self.use_bias else -torch.sum(.5 * self.prior_prec *
                                                       self.bias.pow(2))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_expected_flops_and_l0(self):
        """Measures the expected floating point operations (FLOPs) and the expected L0 norm"""
        # dim_in multiplications and dim_in - 1 additions for each output neuron for the weights
        # + the bias addition for each neuron
        # total_flops = (2 * in_features - 1) * out_features + out_features
        ppos = torch.sum(1 - self.cdf_qz(0))
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.use_bias:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops.item(), expected_l0.item()

    def get_eps(self, size):
        """Uniform random numbers for the concrete distribution"""
        eps = self.floatTensor(size).uniform_(epsilon, 1 - epsilon)
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size, sample=True):
        """Sample the hard-concrete gates for training and use a deterministic value for testing"""
        if sample:
            eps = self.get_eps(self.floatTensor(batch_size, self.in_features))
            z = self.quantile_concrete(eps)
            return F.hardtanh(z, min_val=0, max_val=1)
        else:  # mode
            pi = torch.sigmoid(self.qz_loga).view(1, self.in_features).expand(
                batch_size, self.in_features)
            return F.hardtanh(pi * (limit_b - limit_a) + limit_a,
                              min_val=0,
                              max_val=1)

    def sample_weights(self):
        z = self.quantile_concrete(
            self.get_eps(self.floatTensor(self.in_features)))
        mask = F.hardtanh(z, min_val=0, max_val=1)
        return mask.view(self.in_features, 1) * self.weights

    def forward(self, input):
        if self.local_rep or not self.training:
            z = self.sample_z(input.size(0), sample=self.training)
            xin = input.mul(z)
            output = xin.mm(self.weights)
        else:
            weights = self.sample_weights()
            output = input.mm(weights)
        if self.use_bias:
            output.add_(self.bias)
        return output

    def __repr__(self):
        s = (
            '{name}({in_features} -> {out_features}, droprate_init={droprate_init}, '
            'lamba={lamba}, temperature={temperature}, weight_decay={prior_prec}, '
            'local_rep={local_rep}')
        if not self.use_bias:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
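
# --- Usage sketch (not part of the original listing) ---
# A standalone look at the hard-concrete gate used by L0Dense above. The
# module-level constants limit_a, limit_b and epsilon are assumed here to take
# the values from the L0 paper (Louizos et al., 2018): -0.1, 1.1 and 1e-6.
import torch
import torch.nn.functional as F

limit_a, limit_b, epsilon = -0.1, 1.1, 1e-6
temperature = 2. / 3.
qz_loga = torch.zeros(5)                     # gate logits for 5 input units

# training: sample a stretched concrete variable, then clamp it ("hard" concrete)
u = torch.rand(5).clamp(epsilon, 1 - epsilon)
s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + qz_loga) / temperature)
z_train = F.hardtanh(s * (limit_b - limit_a) + limit_a, 0, 1)

# testing: deterministic gate from the mean of the underlying sigmoid
pi = torch.sigmoid(qz_loga)
z_test = F.hardtanh(pi * (limit_b - limit_a) + limit_a, 0, 1)
print(z_train, z_test)                       # z_test is 0.5 everywhere for zero logits
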
Beispiel #16
0
class CLTLayer(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 alpha=10,
                 isinput=False,
                 isoutput=False):
        super(CLTLayer, self).__init__()
        self.n_in = in_features
        self.n_out = out_features
        self.isoutput = isoutput
        self.isinput = isinput
        self.alpha = alpha

        self.Mbias = nn.Parameter(torch.Tensor(out_features))

        self.M = Parameter(torch.Tensor(out_features, in_features))
        self.logS = nn.Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.M.size(1))
        self.M.data.normal_(0, stdv)
        self.logS.data.zero_().normal_(-9, 0.001)
        self.Mbias.data.zero_()

    def KL(self):
        logS = self.logS.clamp(-11, 11)
        kl = 0.5 * (self.alpha * (self.M.pow(2) + logS.exp()) - logS).sum()
        return kl

    def cdf(self, x, mu=0., sig=1.):
        return 0.5 * (1 + torch.erf((x - mu) / (sig * math.sqrt(2))))

    def pdf(self, x, mu=0., sig=1.):
        return (1 / (math.sqrt(2 * math.pi) * sig)) * torch.exp(-0.5 * (
            (x - mu) / sig).pow(2))

    def relu_moments(self, mu, sig):
        alpha = mu / sig
        cdf = self.cdf(alpha)
        pdf = self.pdf(alpha)
        relu_mean = mu * cdf + sig * pdf
        relu_var = (sig.pow(2) +
                    mu.pow(2)) * cdf + mu * sig * pdf - relu_mean.pow(2)
        return relu_mean, relu_var

    def forward(self, mu_h, var_h):
        M = self.M
        var_s = self.logS.clamp(-11, 11).exp()

        mu_f = F.linear(mu_h, M, self.Mbias)
        # No input variance
        if self.isinput:
            var_f = F.linear(mu_h**2, var_s)
        else:
            var_f = F.linear(var_h + mu_h.pow(2), var_s) + F.linear(
                var_h, M.pow(2))

        # compute relu moments if it is not an output layer
        if not self.isoutput:
            return self.relu_moments(mu_f, var_f.sqrt())
        else:
            return mu_f, var_f

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.n_in) + ' -> ' \
               + str(self.n_out) \
               + f', isinput={self.isinput}, isoutput={self.isoutput})'
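
# --- Usage sketch (not part of the original listing) ---
# relu_moments() above is the standard closed form for the mean and variance of
# ReLU(X) with X ~ N(mu, sig^2). A quick Monte-Carlo sanity check:
import math
import torch

torch.manual_seed(0)
mu, sig = torch.tensor(0.3), torch.tensor(1.2)

alpha = mu / sig
cdf = 0.5 * (1 + torch.erf(alpha / math.sqrt(2)))
pdf = torch.exp(-0.5 * alpha.pow(2)) / math.sqrt(2 * math.pi)
relu_mean = mu * cdf + sig * pdf
relu_var = (sig.pow(2) + mu.pow(2)) * cdf + mu * sig * pdf - relu_mean.pow(2)

samples = torch.relu(mu + sig * torch.randn(1_000_000))
print(relu_mean.item(), samples.mean().item())   # agree to roughly three decimals
print(relu_var.item(), samples.var().item())
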
Beispiel #17
0
class TDConv2d(Module):
    """Implementation of L0 regularization for the feature maps of a convolutional layer"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 dropout=0.5,
                 dropout_botk=0.5,
                 dropout_type="weight",
                 temperature=2.0 / 3.0,
                 weight_decay=1.0,
                 lamba=1.0,
                 local_rep=False,
                 **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Size of the kernel
        :param stride: Stride for the convolution
        :param padding: Padding for the convolution
        :param dilation: Dilation factor for the convolution
        :param groups: How many groups we will assume in the convolution
        :param bias: Whether we will use a bias
        :param weight_decay: Strength of the L2 penalty
        """
        super(TDConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.weight_decay = weight_decay
        self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available()
                            else torch.cuda.FloatTensor)
        self.prune_rate = 0
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.weight = Parameter(
            self.floatTensor(out_channels, in_channels // groups,
                             *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)

        self.dropout = dropout
        self.dropout_type = dropout_type
        self.dropout_botk = dropout_botk

        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode="fan_in")

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, thres_std=1.0):
        pass

    def _reg_w(self, **kwargs):
        logpw = -torch.sum(self.weight_decay * 0.5 * (self.weight.pow(2)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * 0.5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_expected_flops_and_l0(self):
        ppos = self.out_channels
        n = self.kernel_size[0] * self.kernel_size[
            1] * self.in_channels  # vector_length
        flops_per_instance = n + (n - 1
                                  )  # (n: multiplications and n-1: additions)

        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) /
            self.stride[0]) + 1  # for rows
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) /
            self.stride[1]) + 1  # multiplying with cols

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos  # multiply with number of filters
        expected_l0 = n * ppos

        if self.bias is not None:
            # since the gate is applied to the output we also reduce the bias computation
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos

        return expected_flops, expected_l0

    def targeted_dropout(self, w):
        """Zero a random subset of the lowest-magnitude weights (or units)."""
        drop_rate = self.dropout
        targ_perc = self.dropout_botk

        if self.dropout == 0:
            return w

        if self.dropout_type == "weight":
            w_shape = w.size()
            w = w.view(w_shape[0], -1)
            norm = w.abs()
            idx = int(targ_perc * float(w.size()[1]))
            norm_sorted, _ = norm.sort(dim=1)
            threshold = norm_sorted[:, idx]
            # candidate set: the bottom targ_perc fraction of weights per unit
            mask = norm < threshold[:, None]

            if not self.training:
                # at test time the candidates are removed deterministically
                w = (1.0 - mask.float()) * w
                return w.view(w_shape)

            dropout_mask = torch.rand(w.size(), device=w.device) < drop_rate
            mask = dropout_mask & mask
            w = (1.0 - mask.float()) * w
            return w.view(w_shape)

        if self.dropout_type == "unit":
            w_shape = w.size()
            w = w.view(w_shape[0], -1)
            idx = int(targ_perc * float(w.size()[0]))
            norm = w.norm(p=2, dim=1)
            norm_sorted, _ = norm.sort(dim=0)
            threshold = norm_sorted[idx]
            # candidate set: whole units (rows) with the smallest L2 norm
            mask = (norm < threshold)[:, None].expand_as(w)
            dropout_mask = torch.rand(w.size(), device=w.device) < drop_rate
            mask = dropout_mask & mask
            w = (1.0 - mask.float()) * w
            return w.view(w_shape)

        return w

    def prune(self, botk):
        self.prune_rate = botk

    def prune_weights(self, w):
        w_shape = w.size()
        w = w.view(-1, w_shape[-1])
        norm = w.abs()
        idx = int(self.prune_rate * float(w.size()[0]))
        norm_sorted, _ = norm.sort(dim=0)
        threshold = norm_sorted[idx:idx + 1]
        mask = norm >= threshold
        w = mask.float() * w
        w = w.view(w_shape)
        return torch.nn.Parameter(w)

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        weight = self.targeted_dropout(self.weight)
        if self.prune_rate > 0.0:
            weight = self.prune_weights(weight)
        output = F.conv2d(input_, weight, self.bias, self.stride, self.padding,
                          self.dilation, self.groups)
        return output

    def __repr__(self):
        s = (
            "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}, "
            "dropout={dropout}, dropout_botk={dropout_botk}, ")
        if self.padding != (0, ) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1, ) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ", output_padding={output_padding}"
        if self.groups != 1:
            s += ", groups={groups}"
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)
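
# --- Usage sketch (not part of the original listing) ---
# A standalone, CPU-only sketch of the "weight" branch of targeted_dropout()
# above: per output unit, the bottom `targ_perc` fraction of weights by
# magnitude become drop candidates, and each candidate is dropped with
# probability `drop_rate` during training. Sizes here are hypothetical.
import torch

torch.manual_seed(0)
drop_rate, targ_perc = 0.5, 0.5
w = torch.randn(4, 6)                           # (out_units, fan_in)

norm = w.abs()
idx = int(targ_perc * w.size(1))
threshold = norm.sort(dim=1).values[:, idx]     # per-row magnitude cut-off
candidates = norm < threshold[:, None]          # bottom-k weights per row
drop = (torch.rand_like(w) < drop_rate) & candidates
w_dropped = (1.0 - drop.float()) * w
print(candidates.sum(dim=1), drop.sum(dim=1))
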
Beispiel #18
0
class sparse_group_lasso_Conv2d(Module):
    """Implementation of TF1 regularization for the feature maps of a convolutional layer"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 lamba=1.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_channels: Number of input channels
		:param out_channels: Number of output channels
		:param kernel_size: size of the kernel
		:param stride: stride for the convolution
		:param padding: padding for the convolution
		:param dilation: dilation factor for the convolution
		:param groups: how many groups we will assume in the convolution
		:param bias: whether we will use a bias
		:param lamba: strength of the TFL regularization
		"""
        super(sparse_group_lasso_Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.weight_decay = weight_decay
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, thres_std=1.):
        pass

    def _reg_w(self, **kwargs):
        logpw = -self.lamba * np.sqrt(
            self.in_channels * self.kernel_size[0] *
            self.kernel_size[1]) * torch.sum(
                torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1),
                          0.5)) - torch.sum(self.lamba * self.weight.abs())
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.weight.size())

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) /
                          (self.in_channels * self.kernel_size[0] *
                           self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_expected_flops_and_l0(self):
        #ppos = self.out_channels
        ppos = torch.sum(
            torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)

        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) /
            self.stride[0]) + 1
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) /
            self.stride[1]) + 1

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos

        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
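
# --- Usage sketch (not part of the original listing) ---
# The penalty in _reg_w() above is a sparse group lasso over output filters:
# lamba * sqrt(group_size) * sum_f ||W_f||_2 (one group per filter) plus
# lamba * ||W||_1, returned with a negative sign as a log-prior. A standalone
# evaluation on a small, randomly initialised conv weight tensor:
import numpy as np
import torch

torch.manual_seed(0)
lamba = 1.0
weight = torch.randn(8, 4, 3, 3)                # (out, in, kH, kW)
group_size = weight.size(1) * weight.size(2) * weight.size(3)

group_term = np.sqrt(group_size) * weight.pow(2).sum(dim=(1, 2, 3)).sqrt().sum()
l1_term = weight.abs().sum()
log_prior = -(lamba * group_term + lamba * l1_term)
print(log_prior.item())
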
class _ConvNdGroupNJ(BayesianLayers):
    """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding,
                 groups, bias, init_weight, init_bias, cuda=False, clip_var=None):
        super(_ConvNdGroupNJ, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        self.cuda = cuda
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference

        if transposed:
            self.weight_mu = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight_mu = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))

        self.bias_mu = Parameter(torch.Tensor(out_channels))
        self.bias_logvar = Parameter(torch.Tensor(out_channels))

        self.z_mu = Parameter(torch.Tensor(self.out_channels))
        self.z_logvar = Parameter(torch.Tensor(self.out_channels))

        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()
        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # compute the fan-in based init scale
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)

        # init means
        if init_weight is not None:
            self.weight_mu.data = init_weight
        else:
            self.weight_mu.data.uniform_(-stdv, stdv)

        if init_bias is not None:
            self.bias_mu.data = init_bias
        else:
            self.bias_mu.data.fill_(0)

        # init z
        self.z_mu.data.normal_(1, 1e-2)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        part1 = self.z_mu.pow(2) * weight_var
        part2 = z_var * self.weight_mu.pow(2)
        part3 = z_var * weight_var
        self.post_weight_var = part1 + part2 + part3
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -self.weight_logvar + 0.5 * (self.weight_logvar.exp().pow(2) + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -self.bias_logvar + 0.5 * (self.bias_logvar.exp().pow(2) + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if getattr(self, 'bias_mu', None) is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
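
# --- Usage sketch (not part of the original listing) ---
# get_log_dropout_rates() above computes log(alpha) = log(var_z / mu_z^2); the
# KL(q(z)||p(z)) term then uses the approximation from Molchanov et al. (2017)
# with the constants k1, k2, k3. A standalone evaluation for a few channels:
import torch
import torch.nn.functional as F

z_mu = torch.tensor([1.0, 0.9, 0.1])
z_logvar = torch.tensor([-9.0, -2.0, -1.0])

log_alpha = z_logvar - torch.log(z_mu.pow(2) + 1e-8)
k1, k2, k3 = 0.63576, 1.87320, 1.48695
kld_z = -torch.sum(k1 * torch.sigmoid(k2 + k3 * log_alpha)
                   - 0.5 * F.softplus(-log_alpha) - k1)
print(log_alpha, kld_z.item())                  # the last channel has a large alpha
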
class group_relaxed_SCAD_Dense(Module):
	"""Implementation of TFL regularization for the input units of a fully connected layer"""
	def __init__(self, in_features, out_features, bias=True, lamba=1., alpha = 3.7, beta = 4.0, weight_decay=1., **kwargs):
		"""
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the SCAD regularization
		"""
		super(group_relaxed_SCAD_Dense,self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.weight = Parameter(torch.Tensor(in_features, out_features))
		self.u = torch.rand(in_features, out_features)
		if torch.cuda.is_available():
			self.u = self.u.to('cuda')
		if bias:
			self.bias = Parameter(torch.Tensor(out_features))
		else:
			self.register_parameter('bias', None)
		self.lamba = lamba
		self.alpha = alpha
		self.beta = beta
		self.lamba1 = self.lamba/self.beta
		self.weight_decay = weight_decay
		self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
		self.reset_parameters()
		print(self)

	def reset_parameters(self):
		init.kaiming_normal_(self.weight, mode='fan_out')

		if self.bias is not None:
			self.bias.data.normal_(0,1e-2)


	def constrain_parameters(self, **kwargs):
		self.u = self.weight.clone()
		s = Softshrink(self.lamba1)
		#shrinkage on values with absolute value less than 2*lamba1
		shrink_value = s(self.weight.data)
		self.u[self.weight.abs()<=2*self.lamba1] = shrink_value[self.weight.abs()<=2*self.lamba1]

		#modify values whose absolute values are between 2*lamba1 and alpha*lamba1
		modify_weight = self.weight.data
		modify_weight = ((self.alpha - 1)*modify_weight - modify_weight.sign()*(self.alpha*self.lamba1))/(self.alpha - 2)
		self.u[(self.weight.abs()>2*self.lamba1) & (self.weight.abs()<=self.alpha*self.lamba1)] = modify_weight[(self.weight.abs()>2*self.lamba1) & (self.weight.abs()<=self.alpha*self.lamba1)]


	def grow_beta(self, growth_factor):
		self.beta = self.beta*growth_factor
		self.lamba1 = self.lamba/self.beta

	def _reg_w(self, **kwargs):
		logpw = -self.beta*torch.sum(0.5*self.weight.add(-self.u).pow(2))-self.lamba*np.sqrt(self.out_features)*torch.sum(torch.pow(torch.sum(self.weight.pow(2),1),0.5))
		logpb = 0
		if self.bias is not None:
			logpb = - torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
		return logpw + logpb

	def regularization(self):
		return self._reg_w()

	def count_zero_u(self):
		total = np.prod(self.u.size())
		zero = total - self.u.nonzero().size(0)
		return zero

	def count_zero_w(self):
		return torch.sum((self.weight.abs()<1e-5).int()).item()

	def count_weight(self):
		return np.prod(self.u.size())

	def count_active_neuron(self):
		return torch.sum(torch.sum(self.weight.abs()/self.out_features,1)>1e-5).item()

	def count_total_neuron(self):
		return self.in_features

	def count_expected_flops_and_l0(self):
		ppos = torch.sum(self.weight.abs()>0.000001).item()
		expected_flops = (2*ppos-1)*self.out_features
		expected_l0 = ppos*self.out_features
		if self.bias is not None:
			expected_flops += self.out_features
			expected_l0 += self.out_features
		return expected_flops, expected_l0

	def forward(self, input):
		output = input.mm(self.weight)
		if self.bias is not None:
			output.add_(self.bias.view(1, self.out_features).expand_as(output))
		return output

	def __repr__(self):
		return self.__class__.__name__+' (' \
			+ str(self.in_features) + ' -> ' \
			+ str(self.out_features) + ', lambda: ' \
			+ str(self.lamba) + ')'
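
# --- Usage sketch (not part of the original listing) ---
# The SCAD-style update in constrain_parameters() above has three regimes
# (with lamba1 = lamba / beta): soft-thresholding for |w| <= 2*lamba1, a
# rescaled shrinkage for 2*lamba1 < |w| <= alpha*lamba1, and no change for
# larger |w|. A standalone application to a small 1-D tensor:
import torch
import torch.nn as nn

lamba, alpha, beta = 1.0, 3.7, 4.0
lamba1 = lamba / beta                                  # 0.25
w = torch.tensor([-2.0, -0.8, -0.3, 0.2, 0.6, 1.5])

u = w.clone()
soft = nn.Softshrink(lamba1)(w)                        # shrink |w| by lamba1, keeping the sign
small = w.abs() <= 2 * lamba1
mid = (w.abs() > 2 * lamba1) & (w.abs() <= alpha * lamba1)
u[small] = soft[small]
u[mid] = ((alpha - 1) * w[mid] - w[mid].sign() * alpha * lamba1) / (alpha - 2)
print(u)                                               # only -2.0 and 1.5 are left untouched
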
Beispiel #21
0
class MAPConv2d(Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        weight_decay=1.0,
        **kwargs
    ):
        super(MAPConv2d, self).__init__()
        self.weight_decay = weight_decay
        self.floatTensor = (
            torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode="fan_in")

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, thres_std=1.0):
        pass

    def _reg_w(self, **kwargs):
        logpw = -torch.sum(self.weight_decay * 0.5 * (self.weight.pow(2)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * 0.5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_expected_flops_and_l0(self):
        ppos = self.out_channels
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels  # vector_length
        flops_per_instance = n + (n - 1)  # (n: multiplications and n-1: additions)

        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]
        ) + 1  # for rows
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]
        ) + 1  # multiplying with cols

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos  # multiply with number of filters
        expected_l0 = n * ppos

        if self.bias is not None:
            # since the gate is applied to the output we also reduce the bias computation
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos

        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(
            input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return output

    def __repr__(self):
        s = (
            "{name}({in_channels}, {out_channels}, kernel_size={kernel_size} "
            ", stride={stride}, weight_decay={weight_decay}"
        )
        if self.padding != (0,) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.output_padding != (0,) * len(self.output_padding):
            s += ", output_padding={output_padding}"
        if self.groups != 1:
            s += ", groups={groups}"
        if self.bias is None:
            s += ", bias=False"
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)
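
# --- Usage sketch (not part of the original listing) ---
# A worked example of count_expected_flops_and_l0() above with every output
# channel kept (ppos = out_channels): the FLOP estimate is
# (2n - 1) * positions * filters plus one bias add per position per filter,
# where n = kH * kW * in_channels. Input size and layer shape are hypothetical.
in_channels, out_channels, k, stride, padding = 3, 16, 3, 1, 1
input_h = input_w = 32

n = k * k * in_channels                                 # 27 multiplications per position
flops_per_instance = n + (n - 1)                        # 53 (multiplications + additions)
positions = (((input_h - k + 2 * padding) // stride) + 1) \
    * (((input_w - k + 2 * padding) // stride) + 1)     # 32 * 32 = 1024
expected_flops = flops_per_instance * positions * out_channels
expected_flops += positions * out_channels              # bias additions
expected_l0 = n * out_channels + out_channels
print(expected_flops, expected_l0)                      # 884736 448
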
Beispiel #22
0
class L0Conv2d(Module):
    """Implementation of L0 regularization for the feature maps of a convolutional layer"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 droprate_init=0.5,
                 temperature=2. / 3.,
                 weight_decay=1.,
                 lamba=1.,
                 local_rep=False,
                 **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Size of the kernel
        :param stride: Stride for the convolution
        :param padding: Padding for the convolution
        :param dilation: Dilation factor for the convolution
        :param groups: How many groups we will assume in the convolution
        :param bias: Whether we will use a bias
        :param droprate_init: Dropout rate that the L0 gates will be initialized to
        :param temperature: Temperature of the concrete distribution
        :param weight_decay: Strength of the L2 penalty
        :param lamba: Strength of the L0 penalty
        :param local_rep: Whether we will use a separate gate sample per element in the minibatch
        """
        super(L0Conv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.prior_prec = weight_decay
        self.lamba = lamba
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.temperature = temperature
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.use_bias = False
        self.weights = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.qz_loga = Parameter(torch.Tensor(out_channels))
        self.dim_z = out_channels
        self.input_shape = None
        self.local_rep = local_rep

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            self.use_bias = True

        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weights, mode='fan_in')

        self.qz_loga.data.normal_(
            math.log(1 - self.droprate_init) - math.log(self.droprate_init),
            1e-2)

        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.qz_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

    def cdf_qz(self, x):
        """Implements the CDF of the 'stretched' concrete distribution"""
        xn = (x - limit_a) / (limit_b - limit_a)
        logits = math.log(xn) - math.log(1 - xn)
        return torch.sigmoid(logits * self.temperature - self.qz_loga).clamp(
            min=epsilon, max=1 - epsilon)

    def quantile_concrete(self, x):
        """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
        y = torch.sigmoid((torch.log(x) - torch.log(1 - x) + self.qz_loga) /
                          self.temperature)
        return y * (limit_b - limit_a) + limit_a

    def _reg_w(self):
        """Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty"""
        q0 = self.cdf_qz(0)
        logpw_col = torch.sum(
            -(.5 * self.prior_prec * self.weights.pow(2)) - self.lamba,
            3).sum(2).sum(1)
        logpw = torch.sum((1 - q0) * logpw_col)
        logpb = 0 if not self.use_bias else -torch.sum(
            (1 - q0) * (.5 * self.prior_prec * self.bias.pow(2) - self.lamba))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_expected_flops_and_l0(self):
        """Measures the expected floating point operations (FLOPs) and the expected L0 norm"""
        ppos = torch.sum(1 - self.cdf_qz(0))
        n = self.kernel_size[0] * self.kernel_size[
            1] * self.in_channels  # vector_length
        flops_per_instance = n + (n - 1
                                  )  # (n: multiplications and n-1: additions)

        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) /
            self.stride[0]) + 1  # for rows
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) /
            self.stride[1]) + 1  # multiplying with cols

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos  # multiply with number of filters
        expected_l0 = n * ppos

        if self.use_bias:
            # since the gate is applied to the output we also reduce the bias computation
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos

        return expected_flops.item(), expected_l0.item()

    def get_eps(self, size):
        """Uniform random numbers for the concrete distribution"""
        eps = self.floatTensor(size).uniform_(epsilon, 1 - epsilon)
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size, sample=True):
        """Sample the hard-concrete gates for training and use a deterministic value for testing"""
        if sample:
            eps = self.get_eps(self.floatTensor(batch_size, self.dim_z))
            z = self.quantile_concrete(eps).view(batch_size, self.dim_z, 1, 1)
            return F.hardtanh(z, min_val=0, max_val=1)
        else:  # mode
            pi = torch.sigmoid(self.qz_loga).view(1, self.dim_z, 1, 1)
            return F.hardtanh(pi * (limit_b - limit_a) + limit_a,
                              min_val=0,
                              max_val=1)

    def sample_weights(self):
        z = self.quantile_concrete(self.get_eps(self.floatTensor(
            self.dim_z))).view(self.dim_z, 1, 1, 1)
        return F.hardtanh(z, min_val=0, max_val=1) * self.weights

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        b = None if not self.use_bias else self.bias
        if self.local_rep or not self.training:
            output = F.conv2d(input_, self.weights, b, self.stride,
                              self.padding, self.dilation, self.groups)
            z = self.sample_z(output.size(0), sample=self.training)
            return output.mul(z)
        else:
            weights = self.sample_weights()
            output = F.conv2d(input_, weights, None, self.stride, self.padding,
                              self.dilation, self.groups)
            return output

    def __repr__(self):
        s = (
            '{name}({in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}, '
            'droprate_init={droprate_init}, temperature={temperature}, prior_prec={prior_prec}, '
            'lamba={lamba}, local_rep={local_rep}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if not self.use_bias:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
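
# --- Usage sketch (not part of the original listing) ---
# At test time L0Conv2d gates whole output channels deterministically: the gate
# is hardtanh(sigmoid(qz_loga) * (limit_b - limit_a) + limit_a, 0, 1), so
# channels whose logit is pushed far enough negative become exactly zero.
# A standalone check, assuming limit_a = -0.1 and limit_b = 1.1 as in the L0 paper:
import torch
import torch.nn.functional as F

limit_a, limit_b = -0.1, 1.1
qz_loga = torch.tensor([-4.0, -2.3, 0.0, 3.0])   # one logit per output channel
gate = F.hardtanh(torch.sigmoid(qz_loga) * (limit_b - limit_a) + limit_a, 0, 1)
active = (gate > 0).sum().item()
print(gate, active)                              # the first channel prunes to exactly 0
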
Beispiel #23
0
class MAPDense(Module):
    def __init__(self, in_features, out_features, bias=True, weight_decay=1.0, **kwargs):
        super(MAPDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.weight_decay = weight_decay
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter("bias", None)
        self.floatTensor = (
            torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        )
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode="fan_out")

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        pass

    def _reg_w(self, **kwargs):
        logpw = -torch.sum(self.weight_decay * 0.5 * (self.weight.pow(2)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * 0.5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_expected_flops_and_l0(self):
        # dim_in multiplications and dim_in - 1 additions for each output neuron for the weights
        # + the bias addition for each neuron
        # total_flops = (2 * in_features - 1) * out_features + out_features
        expected_flops = (2 * self.in_features - 1) * self.out_features
        expected_l0 = self.in_features * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return (
            self.__class__.__name__
            + " ("
            + str(self.in_features)
            + " -> "
            + str(self.out_features)
            + ", weight_decay: "
            + str(self.weight_decay)
            + ")"
        )
Beispiel #24
0
class L0Dense(nn.Module):
    """
    Implementation of L0 regularization for the input units of a fully connected layer
    """
    def __init__(self,
                 feature,
                 embed_dim,
                 weight_decay=0.0005,
                 droprate=0.5,
                 bias=False,
                 temperature=2. / 3.,
                 lamda=1.,
                 local_rep=False,
                 **kwargs):
        """
        feature: input dimension
        embed_dim: output dimension
        bias: whether to use a bias
        weight_decay: strength of the L2 penalty
        droprate: dropout rate that the L0 gates will be initialized to
        temperature: temperature of the concrete distribution
        lamda: strength of the L0 penalty
        local_rep: whether to use a separate gate sample per element in the minibatch
        """
        super(L0Dense, self).__init__()

        self.feature = feature
        self.embed_dim = embed_dim
        self.prior_prec = weight_decay
        self.temperature = temperature
        self.droprate = droprate
        self.lamda = lamda
        self.use_bias = bias
        self.local_rep = local_rep

        self.weights = Parameter(torch.Tensor(feature, embed_dim))
        # one gate logit per input feature (a single row)
        self.qz_loga = Parameter(torch.Tensor(feature))
        if bias:
            self.bias = Parameter(torch.Tensor(embed_dim))

        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_normal_(self.weights, mode='fan_out')

        self.qz_loga.data.normal_(
            math.log(1 - self.droprate) - math.log(self.droprate), 1e-2)

        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.qz_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

    def cdf_qz(self, x):
        # Implements CDF of the stretched concrete distribution
        xn = (x - limit_a) / (limit_b - limit_a)
        logits = math.log(xn) - math.log(1 - xn)
        return torch.sigmoid(logits * self.temperature - self.qz_loga).clamp(
            min=epsilon, max=1 - epsilon)

    def quantile_concrete(self, x):
        # Implements the quantile of stretched concrete distribution
        y = torch.sigmoid((torch.log(x) - torch.log(1 - x) + self.qz_loga) /
                          self.temperature)
        return y * (limit_b - limit_a) + limit_a

    def _reg_w(self):
        # Expected L0 norm under the stochastic gates
        logpw_col = torch.sum(
            -(.5 * self.prior_prec * self.weights.pow(2)) - self.lamda, 1)
        logpw = torch.sum((1 - self.cdf_qz(0)) * logpw_col)
        logpb = 0 if not self.use_bias else -torch.sum(.5 * self.prior_prec *
                                                       self.bias.pow(2))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def get_eps(self, size):
        # Uniform random numbers for the concrete distribution
        eps = self.floatTensor(size).uniform_(epsilon, 1 - epsilon)
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size, sample=True):
        # Sample the hard-concrete gates for training and use a deterministic value for testing
        # training
        if sample:
            eps = self.get_eps(self.floatTensor(batch_size, self.feature))
            z = self.quantile_concrete(eps)
            return F.hardtanh(z, min_val=0, max_val=1)
        # testing
        else:
            pi = torch.sigmoid(self.qz_loga).view(1, self.feature).expand(
                batch_size, self.feature)
            return F.hardtanh(pi * (limit_b - limit_a) + limit_a,
                              min_val=0,
                              max_val=1)

    def sample_weights(self):
        z = self.quantile_concrete(self.get_eps(self.floatTensor(
            self.feature)))
        mask = F.hardtanh(z, min_val=0, max_val=1)
        return mask.view(self.feature, 1) * self.weights

    def forward(self, input):
        if self.local_rep or not self.training:
            z = self.sample_z(input.size(0), sample=self.training)
            xin = input.mul(z)
            output = xin.mm(self.weights)
        else:
            weights = self.sample_weights()
            output = input.mm(weights)

        if self.use_bias:
            output.add_(self.bias)
        return output
class HardConcreteGatedLinear(Module):
    """
    Linear layer with stochastic connections, as in
    https://arxiv.org/abs/1712.01312
    """
    def __init__(self,
                 in_features,
                 out_features,
                 l0_strength=1.,
                 l2_strength=1.,
                 bias=True,
                 learn_weight=True,
                 droprate_init=0.5,
                 temperature=(2 / 3),
                 **kwargs):
        """
        :param in_features: Input dimensionality
        :param out_features: Output dimensionality
        :param bias: Whether we use a bias
        :param l2_strength: Strength of the L2 penalty
        :param droprate_init: Dropout rate that the L0 gates will be initialized
                              to
        :param temperature: Temperature of the concrete distribution
        :param l0_strength: Strength of the L0 penalty
        """
        super(HardConcreteGatedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.l0_strength = l0_strength
        self.l2_strength = l2_strength
        self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available()
                            else torch.cuda.FloatTensor)
        weight = torch.Tensor(out_features, in_features)
        if learn_weight:
            self.weight = Parameter(weight)
        else:
            self.register_buffer("weight", weight)
        self.loga = Parameter(torch.Tensor(out_features, in_features))
        self.temperature = temperature
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5

        if bias:
            bias = torch.Tensor(out_features)
            if learn_weight:
                self.bias = Parameter(bias)
            else:
                self.register_buffer("bias", bias)
            self.use_bias = True
        else:
            self.use_bias = False
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode="fan_out")
        self.loga.data.normal_(
            math.log(1 - self.droprate_init) - math.log(self.droprate_init),
            1e-2)
        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

    def cdf_qz(self, x):
        """Implements the CDF of the 'stretched' concrete distribution"""
        xn = (x - LIMIT_A) / (LIMIT_B - LIMIT_A)
        logits = math.log(xn) - math.log(1 - xn)
        return torch.sigmoid(logits * self.temperature - self.loga).clamp(
            min=EPSILON, max=1 - EPSILON)

    def quantile_concrete(self, x):
        """
        Implements the quantile, aka inverse CDF, of the 'stretched' concrete
        distribution
        """
        y = torch.sigmoid(
            (torch.log(x) - torch.log(1 - x) + self.loga) / self.temperature)
        return y * (LIMIT_B - LIMIT_A) + LIMIT_A

    def weight_size(self):
        return self.weight.size()

    def regularization(self):
        """
        Expected L0 norm under the stochastic gates, takes into account and
        re-weights also a potential L2 penalty
        """

        weight_decay_ungated = .5 * self.l2_strength * self.weight.pow(2)
        weight_l2_l0 = torch.sum(
            (weight_decay_ungated + self.l0_strength) * (1 - self.cdf_qz(0)))
        bias_l2 = (0 if not self.use_bias else torch.sum(.5 *
                                                         self.l2_strength *
                                                         self.bias.pow(2)))
        return weight_l2_l0 + bias_l2

    def count_inference_flops(self):
        # For each unit, multiply with its n inputs then do n - 1 additions.
        # To capture the -1, subtract it, but only in cases where there is at
        # least one weight.
        nz_by_unit = self.get_inference_nonzeros()
        multiplies = torch.sum(nz_by_unit)
        adds = multiplies - torch.sum(nz_by_unit > 0)
        return multiplies.item(), adds.item()

    def count_expected_flops_and_l0(self):
        """
        Measures the expected floating point operations (FLOPs) and the expected
        L0 norm

        Copied from the original L0 paper code
        """
        # dim_in multiplications and dim_in - 1 additions for each output unit
        # for the weights, plus the bias addition for each unit
        # total_flops = (2 * in_features - 1) * out_features + out_features
        ppos = torch.sum(1 - self.cdf_qz(0))
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.use_bias:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops.item(), expected_l0.item()

    def get_eps(self, size):
        """Uniform random numbers for the concrete distribution"""
        eps = self.floatTensor(size).uniform_(EPSILON, 1 - EPSILON)
        eps = Variable(eps)
        return eps

    def sample_weight(self):
        if self.training:
            z = self.quantile_concrete(
                self.get_eps(self.floatTensor(self.loga.size())))
            mask = F.hardtanh(z, min_val=0, max_val=1)
        else:
            pi = torch.sigmoid(self.loga)
            mask = F.hardtanh(pi * (LIMIT_B - LIMIT_A) + LIMIT_A,
                              min_val=0,
                              max_val=1)

        return mask * self.weight

    def forward(self, x):
        return F.linear(x, self.sample_weight(),
                        (self.bias if self.use_bias else None))

    def get_expected_nonzeros(self):
        expected_gates = 1 - self.cdf_qz(0)
        return expected_gates.sum(
            dim=tuple(range(1, len(expected_gates.shape)))).detach()

    def get_inference_nonzeros(self):
        inference_gates = F.hardtanh(torch.sigmoid(self.loga) *
                                     (LIMIT_B - LIMIT_A) + LIMIT_A,
                                     min_val=0,
                                     max_val=1)
        return (inference_gates > 0).sum(
            dim=tuple(range(1, len(inference_gates.shape)))).detach()

    def __repr__(self):
        s = ("{name}({in_features} -> {out_features}, "
             "droprate_init={droprate_init}, l0_strength={l0_strength}, "
             "temperature={temperature}, l2_strength={l2_strength}, ")
        if not self.use_bias:
            s += ", bias=False"
        s += ")"
        return s.format(name=self.__class__.__name__, **self.__dict__)
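
# Hedged usage sketch, not part of the original source: one way a
# HardConcreteGatedLinear layer might be trained, adding the expected-L0/L2
# penalty returned by regularization() to the task loss. The shapes, loss and
# hyperparameters below are illustrative assumptions.
def _hard_concrete_usage_sketch():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layer = HardConcreteGatedLinear(in_features=20, out_features=10,
                                    l0_strength=1e-3, l2_strength=1e-4).to(device)
    optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
    x = torch.randn(32, 20, device=device)
    target = torch.randn(32, 10, device=device)

    layer.train()
    out = layer(x)                       # gates are sampled stochastically here
    loss = F.mse_loss(out, target) + layer.regularization()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    layer.constrain_parameters()         # keep log-alpha in a numerically safe range

    layer.eval()                         # deterministic gates at inference time
    print(layer.get_inference_nonzeros().sum().item(), "weights kept")
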
class Conv2d(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_width=(1, 1),
            stride=(1, 1),
            dilation=(1, 1),
            g_init=1.0,
            bias_init=0.1,
            causal=False,
            activation=None,
    ):
        super(Conv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_width = kernel_width
        self.stride = stride
        self.dilation = dilation
        self.causal = causal
        self.activation = activation

        self.generating = False
        self.generating_reset = True
        self._weight = None
        self._input_cache = None

        self.padding = tuple(d * (w-1)//2 for w, d in zip(kernel_width, dilation))

        self.bias = Parameter(torch.Tensor(out_channels))
        self.weight_v = Parameter(torch.Tensor(out_channels, in_channels, *kernel_width))
        self.weight_g = Parameter(torch.Tensor(out_channels))

        if causal:
            if any(w % 2 == 0 for w in kernel_width):
                raise HyperparameterError(f"Even kernel width incompatible with causal convolution: {kernel_width}")
            if kernel_width == (1, 3):  # make common case explicit
                mask = torch.Tensor([1., 1., 0.])
            elif kernel_width[0] == 1:
                mask = torch.ones(kernel_width)
                mask[0, kernel_width[1] // 2 + 1:] = 0
            else:
                mask = torch.ones(kernel_width)
                mask[kernel_width[0] // 2, kernel_width[1] // 2:] = 0
                mask[kernel_width[0] // 2 + 1:, :] = 0

            mask = mask.view(1, 1, *kernel_width)
            self.register_buffer('mask', mask)
        else:
            self.register_buffer('mask', None)

        self.reset_parameters(g_init=g_init, bias_init=bias_init)

    def reset_parameters(self, v_mean=0., v_std=0.05, g_init=1.0, bias_init=0.1):
        nn.init.normal_(self.weight_v, mean=v_mean, std=v_std)
        nn.init.constant_(self.weight_g, val=g_init)
        nn.init.constant_(self.bias, val=bias_init)

    def generate(self, mode=True):
        self.generating = mode
        self.generating_reset = True
        self._weight = None
        self._input_cache = None
        return self

    def weight_costs(self):
        return (
            self.weight_v.pow(2).sum(),
            self.weight_g.pow(2).sum(),
            self.bias.pow(2).sum()
        )

    @property
    def weight(self):
        shape = (self.out_channels, 1, 1, 1)
        weight = l2_norm_except_dim(self.weight_v, 0) * self.weight_g.view(shape)
        if self.mask is not None:
            weight = weight * self.mask
        return weight

    def forward(self, inputs):
        """
        :param inputs: (N, C_in, H, W)
        :return: (N, C_out, H, W)
        """
        if self.generating:
            if self.generating_reset:
                self.generating_reset = False
                if self.kernel_width != (1, 1):
                    self._input_cache = inputs
            else:
                return self.forward_generate(inputs)

        h = F.conv2d(inputs, self.weight, bias=self.bias,
                     stride=self.stride, padding=self.padding, dilation=self.dilation)
        if self.activation is not None:
            h = self.activation(h)
        return h

    def forward_generate(self, inputs):
        """Calculates forward for the last position in `inputs`
        Only implemented for kernel widths (1, 1) and (1, 3) and stride (1, 1).
        If the kernel width is (1, 3), causal must be True.

        :param inputs: tensor(N, C_in, 1, 1)
        :return: tensor(N, C_out, 1, 1)
        """
        if self._weight is None:
            self._weight = self.weight
            self._weight = self._weight.transpose(0, 1)
        if self.kernel_width == (1, 1):
            h = inputs[:, :, 0, -1] @ self._weight[:, :, 0, 0] + self.bias.view(1, self.out_channels)
        elif self.kernel_width == (1, 3):
            h = inputs[:, :, 0, -1] @ self._weight[:, :, 0, 1]
            if self.dilation[1] < self._input_cache.size(3):
                h += self._input_cache[:, :, 0, -self.dilation[1]] @ self._weight[:, :, 0, 0]
            h += self.bias.view(1, self.out_channels)
            self._input_cache = torch.cat([self._input_cache, inputs], dim=3)
        else:
            raise HyperparameterError(f"Generate not supported for kernel width {self.kernel_width}.")
        if self.activation is not None:
            h = self.activation(h)
        return h.unsqueeze(-1).unsqueeze(-1)

    def extra_repr(self):
        s = '{in_channels}, {out_channels}, kernel_size={kernel_width}'
        if self.stride != (1,) * len(self.stride):
            s += ', stride={stride}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.causal:
            s += ', causal=True'
        return s.format(**self.__dict__)
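
# The `weight` property of Conv2d above relies on an `l2_norm_except_dim` helper
# that is not defined in this snippet. A minimal sketch of the assumed behaviour
# (weight-norm style: divide v by its L2 norm taken over every dimension except
# `dim`, so that weight = v / ||v|| * g per output channel):
def l2_norm_except_dim(v, dim=0, eps=1e-12):
    # Collapse all dimensions except `dim` and compute one L2 norm per slice.
    norms = v.transpose(0, dim).contiguous().view(v.size(dim), -1).norm(dim=1)
    shape = [1] * v.dim()
    shape[dim] = v.size(dim)
    return v / (norms.view(shape) + eps)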
Beispiel #27
0
class BinaryGatedLinear(Module):
    """
    Linear layer with stochastic binary gates
    """
    def __init__(self,
                 in_features,
                 out_features,
                 l0_strength=1.,
                 l2_strength=1.,
                 learn_weight=True,
                 bias=True,
                 droprate_init=0.5,
                 **kwargs):
        """
        :param in_features: Input dimensionality
        :param out_features: Output dimensionality
        :param bias: Whether we use a bias
        :param l2_strength: Strength of the L2 penalty
        :param droprate_init: Dropout rate that the gates will be initialized to
        :param l0_strength: Strength of the L0 penalty
        """
        super(BinaryGatedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.l0_strength = l0_strength
        self.l2_strength = l2_strength
        self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available()
                            else torch.cuda.FloatTensor)
        weight = torch.Tensor(out_features, in_features)
        if learn_weight:
            self.weight = Parameter(weight)
        else:
            self.register_buffer("weight", weight)

        self.logit_p1 = Parameter(torch.Tensor(out_features, in_features))
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.use_bias = False
        if bias:
            b = torch.Tensor(out_features)
            if learn_weight:
                self.bias = Parameter(b)
            else:
                self.register_buffer("bias", b)
            self.use_bias = True
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode="fan_out")
        self.logit_p1.data.normal_(
            math.log(1 - self.droprate_init) - math.log(self.droprate_init),
            1e-2)
        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        pass

    def regularization(self):
        """
        Expected L0 norm under the stochastic gates; also takes into account
        and re-weights a potential L2 penalty
        """
        p1 = torch.sigmoid(self.logit_p1)
        weight_decay_ungated = .5 * self.l2_strength * self.weight.pow(2)
        weight_l2_l0 = torch.sum(
            (weight_decay_ungated + self.l0_strength) * p1)
        bias_l2 = (0 if not self.use_bias else torch.sum(.5 *
                                                         self.l2_strength *
                                                         self.bias.pow(2)))
        return -weight_l2_l0 - bias_l2

    def sample_weight(self):
        u = self.floatTensor(self.logit_p1.size()).uniform_(0, 1)

        p1 = torch.sigmoid(self.logit_p1)
        mask = p1 > u

        def cc_to_p1(grad):
            ratio = p1 / (1 - p1)
            p1.backward(grad * torch.where(mask, 1 / ratio, ratio))
            return grad

        z = mask.float()
        z.requires_grad_()
        z.register_hook(cc_to_p1)

        return self.weight * z

    def forward(self, x):
        return F.linear(x, self.sample_weight(),
                        (self.bias if self.use_bias else None))

    def get_expected_nonzeros(self):
        expected_gates = torch.sigmoid(self.logit_p1)
        return expected_gates.sum(
            dim=tuple(range(1, len(expected_gates.shape)))).detach()

    def get_inference_nonzeros(self):
        u = self.floatTensor(self.logit_p1.size()).uniform_(0, 1)
        inference_gates = torch.sigmoid(self.logit_p1) > u
        return inference_gates.sum(
            dim=tuple(range(1, len(inference_gates.shape)))).detach()
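
# Hedged usage sketch, not part of the original source: note that, unlike
# HardConcreteGatedLinear above, BinaryGatedLinear.regularization() returns the
# *negative* expected penalty, so it is subtracted from the task loss here.
# Shapes, loss and hyperparameters are illustrative assumptions.
def _binary_gated_usage_sketch():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layer = BinaryGatedLinear(in_features=20, out_features=10,
                              l0_strength=1e-3, l2_strength=1e-4).to(device)
    optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
    x = torch.randn(32, 20, device=device)
    target = torch.randn(32, 10, device=device)

    out = layer(x)                       # binary gates are sampled per forward pass
    loss = F.mse_loss(out, target) - layer.regularization()
    optimizer.zero_grad()
    loss.backward()                      # the hook in sample_weight() routes
    optimizer.step()                     # gradients into logit_p1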
Beispiel #28
0
class group_relaxed_L1L2Conv2d(Module):
    """Implementation of TF1 regularization for the feature maps of a convolutional layer"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 lamba=1.,
                 alpha=1.,
                 beta=4.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_channels: Number of input channels
		:param out_channels: Number of output channels
		:param kernel_size: size of the kernel
		:param stride: stride for the convolution
		:param padding: padding for the convolution
		:param dilation: dilation factor for the convolution
		:param groups: how many groups we will assume in the convolution
		:param bias: whether we will use a bias
		:param lamba: strength of the TFL regularization
		"""
        super(group_relaxed_L1L2Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.u = torch.rand(out_channels, in_channels // groups,
                            *self.kernel_size)
        self.u = self.u.to('cuda' if torch.cuda.is_available() else 'cpu')
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) +
                               self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0)
            self.u[self.u != 0] = self.weight[self.u != 0].sign(
            ) * self.alpha * self.lamba1 / (n**(1 / 2))

        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            max_idx = np.unravel_index(torch.argmax(self.u.cpu(), None),
                                       self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w +
                               (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(
            0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(
                self.in_channels * self.kernel_size[0] * self.kernel_size[1]
            ) * torch.sum(
                torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) /
                          (self.in_channels * self.kernel_size[0] *
                           self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_weight(self):
        return np.prod(self.u.size())

    def count_expected_flops_and_l0(self):
        #ppos = self.out_channels
        ppos = torch.sum(
            torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)

        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) /
            self.stride[0]) + 1
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) /
            self.stride[1]) + 1

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos

        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
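
# Hedged usage sketch, not part of the original source: one possible training
# loop for group_relaxed_L1L2Conv2d, refreshing the relaxed variable `u` after
# every step and growing `beta` occasionally so that the weights are pulled
# toward u. regularization() returns a log-prior-style (negative) penalty,
# hence the subtraction. Epoch count, growth factor, loss and shapes are
# illustrative assumptions.
def _group_relaxed_conv_usage_sketch():
    layer = group_relaxed_L1L2Conv2d(3, 8, kernel_size=3, padding=1,
                                     lamba=0.1, beta=4.)
    if torch.cuda.is_available():
        layer = layer.cuda()
    optimizer = torch.optim.SGD(layer.parameters(), lr=1e-2)
    x = torch.randn(4, 3, 16, 16, device=next(layer.parameters()).device)

    for epoch in range(4):
        out = layer(x)
        loss = out.pow(2).mean() - layer.regularization()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        layer.constrain_parameters()     # recompute u from the current weights
        if epoch % 2 == 1:
            layer.grow_beta(1.5)         # tighten the w-u coupling over time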
Beispiel #29
0
class LinearGroupNJ(Module):
    """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z  
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training)

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
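
# The LinearGroupNJ layers in this file call a `reparametrize` helper that is
# not shown here. A minimal sketch of the assumed behaviour (mu + eps * sigma
# when sampling, mu otherwise; the `cuda` flag is kept only for API
# compatibility, since randn_like already follows the device of its argument):
def reparametrize(mu, logvar, sampling=True, cuda=False):
    if not sampling:
        return mu
    std = logvar.mul(0.5).exp_()
    eps = torch.randn_like(std)
    return mu + eps * std
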
class LinearGroupNJ(BayesianLayers):
    """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None):

        super(LinearGroupNJ, self).__init__()
        self.cuda = cuda
        self.in_features = in_features
        self.out_features = out_features
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        # trainable params according to Eq.(6)
        # dropout params
        self.z_mu = Parameter(torch.Tensor(in_features))
        self.z_logvar = Parameter(torch.Tensor(in_features))  # = z_mu^2 * alpha
        # weight params
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_logvar = Parameter(torch.Tensor(out_features, in_features))

        self.bias_mu = Parameter(torch.Tensor(out_features))
        self.bias_logvar = Parameter(torch.Tensor(out_features))

        # init params either random or with pretrained net
        self.reset_parameters(init_weight, init_bias)

        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        stdv = 1. / math.sqrt(self.weight_mu.size(1))

        self.z_mu.data.normal_(1, 1e-2)

        if init_weight is not None:
            self.weight_mu.data = torch.Tensor(init_weight)
        else:
            self.weight_mu.data.normal_(0, stdv)

        if init_bias is not None:
            self.bias_mu.data = torch.Tensor(init_bias)
        else:
            self.bias_mu.data.fill_(0)

        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        # print("self.z_mu.pow(2): ", self.z_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("z_var: ", z_var.size())
        # print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size())
        # print("weight_var: ", weight_var.size())
        # print("post_weight_mu: ", self.post_weight_mu.size())
        # print("post_weight_var: ", self.post_weight_var.size())
        return self.post_weight_mu, self.post_weight_var

    def forward(self, x):
        if self.deterministic:
            assert self.training == False, "Flag deterministic is True. This should not be used in training."
            return F.linear(x, self.post_weight_mu, self.bias_mu)

        batch_size = x.size()[0]
        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training,
                          cuda=self.cuda)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        xz = x * z
        mu_activations = F.linear(xz, self.weight_mu, self.bias_mu)
        var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp())

        return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda)

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1)

        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        # KL bias
        KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)

        return KLD

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
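
# Hedged usage sketch, not part of the original source: compressed/deterministic
# inference with LinearGroupNJ. After training, the posterior weight means are
# cached via compute_posterior_params() and the deterministic path in forward()
# is taken. Shapes are illustrative.
def _linear_group_nj_inference_sketch():
    layer = LinearGroupNJ(in_features=20, out_features=10)
    layer.compute_posterior_params()     # caches post_weight_mu / post_weight_var
    layer.deterministic = True
    layer.eval()                         # forward() asserts self.training is False
    x = torch.randn(5, 20)
    out = layer(x)                       # F.linear with the posterior mean weights
    # Low log dropout rates indicate input units that survive pruning.
    print(out.shape, layer.get_log_dropout_rates().min().item())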
Beispiel #31
0
class group_relaxed_L1L2Dense(Module):
    """Implementation of TFL regularization for the input units of a fully connected layer"""
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 lamba=1.,
                 alpha=1.,
                 beta=4.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the TF1 regularization
		"""
        super(group_relaxed_L1L2Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.u = torch.rand(in_features, out_features)
        self.u = self.u.to('cuda' if torch.cuda.is_available() else 'cpu')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) +
                               self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0)
            self.u[self.u != 0] = self.weight[self.u != 0].sign(
            ) * self.alpha * self.lamba1 / (n**(1 / 2))

        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            max_idx = np.unravel_index(torch.argmax(self.u.cpu(), None),
                                       self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w +
                               (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(
            0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(
                self.out_features) * torch.sum(
                    torch.pow(torch.sum(self.weight.pow(2), 1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(
            torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__+' (' \
         + str(self.in_features) + ' -> ' \
         + str(self.out_features) + ', lambda: ' \
         + str(self.lamba) + ')'
class HSDense(Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_std=1.,
                 prior_std_z=1.,
                 dof=1.,
                 **kwargs):
        super(HSDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std
        self.mean_w = Parameter(torch.Tensor(in_features, out_features))
        self.logvar_w = Parameter(torch.Tensor(in_features, out_features))
        self.qz_mean = Parameter(torch.Tensor(in_features))
        self.qz_logvar = Parameter(torch.Tensor(in_features))
        self.dof = dof
        self.prior_std_z = prior_std_z
        self.use_bias = False
        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_features))
            self.logvar_bias = Parameter(torch.Tensor(out_features))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_out')
        self.logvar_w.data.normal_(-9., 1e-4)

        self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3)
        self.qz_logvar.data.normal_(math.log(0.1), 1e-4)
        if self.use_bias:
            self.mean_bias.data.normal_(0, 1e-2)
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(
            2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div(
                self.prior_std**2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2)
        logpb = 0.
        if self.use_bias:
            logpb = (-.5 * math.log(2 * math.pi * self.prior_std ** 2)
                     - .5 * self.logvar_bias.exp().div(self.prior_std ** 2))
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2)
        return torch.sum(logpw) + torch.sum(logpb)

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 *
                               (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        z = self.sample_z(1)
        z = z.view(self.in_features)

        logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1))
        logqm = logqm.add(-torch.sum(torch.sigmoid(z.exp().add(-1).log()).log()))

        logpm = torch.sum(
            2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof) -
            math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi) -
            .5 * (self.dof + 1) * torch.log(1. + z.pow(2) /
                                            (self.dof * self.prior_std_z**2)))

        return logpm - logqm

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size):
        z = self.qz_mean.view(1, self.in_features)
        if self.training:
            eps = self.get_eps(self.floatTensor(batch_size, self.in_features))
            z = z + eps.mul(
                self.qz_logvar.view(1, self.in_features).mul(0.5).exp_())
        return F.softplus(z)

    def sample_W(self):
        W = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return W

    def sample_b(self):
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def get_mean_x(self, input):
        mean_xin = input.mm(self.mean_w)
        if self.use_bias:
            mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features))

        return mean_xin

    def get_var_x(self, input):
        var_xin = input.pow(2).mm(self.logvar_w.exp())
        if self.use_bias:
            var_xin = var_xin.add(self.logvar_bias.exp().view(
                1, self.out_features))

        return var_xin

    def forward(self, input):
        # sampling
        batch_size = input.size(0)
        z = self.sample_z(batch_size)
        xin = input.mul(z)
        mean_xin = self.get_mean_x(xin)
        output = mean_xin
        if self.training:
            var_xin = self.get_var_x(xin)
            eps = self.get_eps(self.floatTensor(batch_size, self.out_features))
            output = output.add(var_xin.sqrt().mul(eps))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ', dof: ' \
               + str(self.dof) + ', prior_std_z: ' \
               + str(self.prior_std_z) + ')'
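
# Hedged usage sketch, not part of the original source: HSDense inside an
# ELBO-style objective. kldiv() above returns eq_logpw() - eq_logqw() plus the
# auxiliary term, i.e. the negative KL contribution, so it is subtracted from
# the data-fit loss and scaled by an assumed dataset size N. Shapes, N and the
# loss are illustrative.
def _hsdense_usage_sketch(N=50000):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layer = HSDense(in_features=20, out_features=10).to(device)
    optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
    x = torch.randn(32, 20, device=device)
    target = torch.randn(32, 10, device=device)

    layer.train()
    out = layer(x)                       # local reparameterisation with sampled z
    loss = F.mse_loss(out, target) - layer.kldiv() / N
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    layer.constrain_parameters()         # clamp the weight log-variances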