class MAPDense(Module):

    def __init__(self, in_features, out_features, bias=True, weight_decay=1., **kwargs):
        super(MAPDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.weight_decay = weight_decay
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, thres_std=1.):
        pass

    def eq_logpw(self, **kwargs):
        logpw = -torch.sum(self.weight_decay * .5 * (self.weight.pow(2)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def eq_logqw(self):
        return 0.

    def kldiv_aux(self):
        return 0.

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw() + self.kldiv_aux()

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
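# A minimal usage sketch (not part of the original layer code): `kldiv()` returns the
# weighted L2 log-prior, so a MAP objective subtracts it from the data log-likelihood.
# The model name and the 1/N scaling below are illustrative assumptions.
def map_objective(model, logits, target, n_train):
    nll = F.cross_entropy(logits, target)
    log_prior = sum(m.kldiv() for m in model.modules() if isinstance(m, MAPDense))
    return nll - log_prior / n_train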
class stimGLM(Poisson):

    def __init__(self, input_dim=(12, 1), num_directions=12, output_dim=128, d2t=0.1, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.directionTuning = Parameter(torch.Tensor(
            size=(self.hparams.num_directions, self.hparams.output_dim)))
        self.directionKernel = Parameter(torch.Tensor(
            size=(self.hparams.input_dim[0], self.hparams.output_dim)))
        self.bias = Parameter(torch.Tensor(size=(1, self.hparams.output_dim)))
        self.spikeNL = nn.Softplus()
        self.directionTuning.data = torch.randn(self.directionTuning.shape)
        self.directionKernel.data = torch.randn(self.directionKernel.shape)
        self.bias.data = torch.rand(self.bias.shape)

    def regularizer(self):
        d2tdir = self.directionKernel.diff(dim=0).pow(2).sum()
        l2dir = self.directionTuning.pow(2).sum()
        return self.hparams.d2t * d2tdir + self.hparams.d2t * l2dir
        # self.contrast.weight.data.diff()

    def forward(self, sample):
        x = torch.einsum('nld,lc->ndc', sample['direction'], self.directionKernel)
        x = torch.einsum('ndc,dc->nc', x, self.directionTuning)
        x = self.spikeNL(x + self.bias)
        return x
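# A shape-level smoke test, assuming the (unshown) `Poisson` base class behaves like a
# standard LightningModule so that `save_hyperparameters()` is available. The batch
# layout (batch, lags, directions) is inferred from the einsum subscripts above.
glm = stimGLM(input_dim=(12, 1), num_directions=12, output_dim=128, d2t=0.1)
batch = {'direction': torch.randn(32, 12, 12)}   # (batch, lags, directions)
rates = glm(batch)                               # -> (32, 128) non-negative rates via Softplus
penalty = glm.regularizer()                      # temporal-smoothness + L2 penalty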
class Lap2D(nn.Module):
    """Isotropic Laplacian-style filter, built as a scaled difference of Gaussians."""

    def __init__(self, w, n_gaussian, learn_amplitude=False):
        super(Lap2D, self).__init__()
        self.xes = torch.FloatTensor(range(int(-w / 2) + 1, int(w / 2) + 1)).unsqueeze(-1)
        self.xes = self.xes.repeat(self.xes.size(0), 1, n_gaussian)
        self.yes = self.xes.transpose(1, 0)
        self.xypod = Parameter(self.xes * self.yes, requires_grad=False)
        self.xes = Parameter(self.xes.pow(2), requires_grad=False)
        self.yes = Parameter(self.yes.pow(2), requires_grad=False)
        self.padding = int(w / 2)
        self.s = Parameter(torch.randn(n_gaussian).float(), requires_grad=True)
        print("Current Lap2D")

    def weights_init(self):
        self.s.data.normal_(1., 0.3)

    def get_gaussian(self, s):
        # return (- (self.xes + self.yes) / (2 * s.pow(2))).exp() / (2.4569 * s)
        return (-(self.xes + self.yes) * s.pow(2) / 2).exp() / 2.4569 * s

    def get_filter(self, s=None):
        """Finite-difference approximation of the scale derivative of the Gaussian.

        :param s: per-filter scale parameters; defaults to the learned `self.s`
        :return: filter bank of shape (n_gaussian, 1, w, w)
        """
        if s is None:
            s = self.s
        eps = 1e-3
        k = s.pow(2) / eps
        filters = self.get_gaussian(s) - self.get_gaussian(s + eps)
        return (k * filters).transpose(0, 2).unsqueeze(1).contiguous()

    def forward(self, x):
        filters = self.get_filter(self.s)
        return F.conv2d(x, filters, padding=self.padding, groups=x.size(1))
class AngleSoftmax(nn.Module): def __init__(self, input_size, output_size, normalize=True, m=4, lambda_max=1000.0, lambda_min=5.0, power=1.0, gamma=0.1, loss_weight=1.0): """ :param input_size: Input channel size. :param output_size: Number of Class. :param normalize: Whether do weight normalization. :param m: An integer, specifying the margin type, take value of [0,1,2,3,4,5]. :param lambda_max: Starting value for lambda. :param lambda_min: Minimum value for lambda. :param power: Decreasing strategy for lambda. :param gamma: Decreasing strategy for lambda. :param loss_weight: Loss weight for this loss. """ super(AngleSoftmax, self).__init__() self.loss_weight = loss_weight self.normalize = normalize self.weight = Parameter(torch.Tensor(int(output_size), input_size)) nn.init.kaiming_uniform_(self.weight, 1.0) self.m = m self.it = 0 self.LambdaMin = lambda_min self.LambdaMax = lambda_max self.gamma = gamma self.power = power def forward(self, x, y): if self.normalize: wl = self.weight.pow(2).sum(1).pow(0.5) wn = self.weight / wl.view(-1, 1) self.weight.data.copy_(wn.data) if self.training: lamb = max(self.LambdaMin, self.LambdaMax / (1 + self.gamma * self.it)**self.power) self.it += 1 phi_kernel = PhiKernel(self.m, lamb) feat = phi_kernel(x, self.weight, y) loss = F.nll_loss(F.log_softmax(feat, dim=1), y) else: feat = x.mm(self.weight.t()) self.prob = F.log_softmax(feat, dim=1) loss = F.nll_loss(self.prob, y) return loss.mul_(self.loss_weight)
class VBLinear(nn.Module):

    def __init__(self, in_features, out_features, prior_prec=10, map=True):
        super(VBLinear, self).__init__()
        self.n_in = in_features
        self.n_out = out_features
        self.prior_prec = prior_prec
        self.map = map
        self.bias = nn.Parameter(th.Tensor(out_features))
        self.mu_w = Parameter(th.Tensor(out_features, in_features))
        self.logsig2_w = nn.Parameter(th.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        # TODO: Adapt to the newest pytorch initializations
        stdv = 1. / math.sqrt(self.mu_w.size(1))
        self.mu_w.data.normal_(0, stdv)
        self.logsig2_w.data.zero_().normal_(-9, 0.001)  # var init via Louizos
        self.bias.data.zero_()

    def KL(self, loguniform=False):
        if loguniform:
            k1, k2, k3 = 0.63576, 1.87320, 1.48695
            log_alpha = self.logsig2_w - 2 * th.log(self.mu_w.abs() + 1e-8)
            kl = -th.sum(k1 * th.sigmoid(k2 + k3 * log_alpha)
                         - 0.5 * F.softplus(-log_alpha) - k1)
        else:
            logsig2_w = self.logsig2_w.clamp(-11, 11)
            kl = 0.5 * (self.prior_prec * (self.mu_w.pow(2) + logsig2_w.exp())
                        - logsig2_w - 1 - np.log(self.prior_prec)).sum()
        return kl

    def forward(self, input):
        # Sampling-free forward pass only if MAP prediction and not training
        if self.map and not self.training:
            return F.linear(input, self.mu_w, self.bias)
        else:
            mu_out = F.linear(input, self.mu_w, self.bias)
            logsig2_w = self.logsig2_w.clamp(-11, 11)
            s2_w = logsig2_w.exp()
            var_out = F.linear(input.pow(2), s2_w) + 1e-8
            return mu_out + var_out.sqrt() * th.randn_like(mu_out)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.n_in) + ' -> ' \
            + str(self.n_out) + ')'
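# A minimal ELBO-style training sketch, assuming a classifier `model` built from
# VBLinear layers; scaling the KL by the training-set size is the usual convention
# for such local-reparameterization layers, not something the class itself enforces.
def elbo_objective(model, x, y, n_train):
    logits = model(x)   # training mode samples activations via the reparameterization
    nll = F.cross_entropy(logits, y)
    kl = sum(m.KL() for m in model.modules() if isinstance(m, VBLinear))
    return nll + kl / n_train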
class WeightNormalizedLinear(Module):

    def __init__(self, in_features, out_features, scale=True, bias=True,
                 init_factor=1, init_scale=1):
        super(WeightNormalizedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.zeros(1, out_features))
        else:
            self.register_parameter('bias', None)
        if scale:
            self.scale = Parameter(torch.Tensor(1, out_features).fill_(init_scale))
        else:
            self.register_parameter('scale', None)
        self.reset_parameters(init_factor)

    def reset_parameters(self, factor):
        stdv = 1. * factor / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def weight_norm(self):
        # keepdim so the per-output-unit norm stays a (out_features, 1) column vector
        return self.weight.pow(2).sum(1, keepdim=True).add(1e-6).sqrt()

    def norm_scale_bias(self, input):
        output = input.div(self.weight_norm().transpose(0, 1).expand_as(input))
        if self.scale is not None:
            output = output.mul(self.scale.expand_as(input))
        if self.bias is not None:
            output = output.add(self.bias.expand_as(input))
        return output

    def forward(self, input):
        return self.norm_scale_bias(F.linear(input, self.weight))

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
class BinaryGatedLinear(Module): """ Linear layer with stochastic binary gates """ def __init__(self, in_features, out_features, l0_strength=1., l2_strength=1., learn_weight=True, bias=True, droprate_init=0.5, random_weight=True, deterministic=False, use_baseline_bias=False, optimize_inference=False, one_sample_per_item=False, **kwargs): """ :param in_features: Input dimensionality :param out_features: Output dimensionality :param bias: Whether we use a bias :param l2_strength: Strength of the L2 penalty :param droprate_init: Dropout rate that the gates will be initialized to :param l0_strength: Strength of the L0 penalty """ super(BinaryGatedLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.l0_strength = l0_strength self.l2_strength = l2_strength self.deterministic = deterministic self.use_baseline_bias = use_baseline_bias self.optimize_inference = optimize_inference self.one_sample_per_item = one_sample_per_item self.random_weight = random_weight if random_weight: exc_weight = torch.Tensor(out_features, in_features) inh_weight = torch.Tensor(out_features, in_features) else: exc_weight = torch.ones(out_features, in_features) inh_weight = torch.ones(out_features, in_features) if learn_weight: self.exc_weight = Parameter(exc_weight) self.inh_weight = Parameter(inh_weight) else: self.register_buffer("exc_weight", exc_weight) self.register_buffer("inh_weight", inh_weight) self.exc_p1 = Parameter(torch.Tensor(out_features, in_features)) self.inh_p1 = Parameter(torch.Tensor(out_features, in_features)) self.droprate_init = droprate_init if droprate_init != 0. else 0.5 self.use_bias = bias if bias: self.bias = Parameter(torch.Tensor(out_features)) self.reset_parameters() def reset_parameters(self): if self.random_weight: init.kaiming_normal_(self.exc_weight, mode="fan_out") init.kaiming_normal_(self.inh_weight, mode="fan_out") self.exc_weight.data.abs_() self.inh_weight.data.abs_() self.exc_p1.data.normal_(1 - self.droprate_init, 1e-2) self.inh_p1.data.normal_(1 - self.droprate_init, 1e-2) if self.use_bias: self.bias.data.fill_(0) def constrain_parameters(self, **kwargs): self.exc_weight.data.clamp_(min=0.) self.inh_weight.data.clamp_(min=0.) def get_gate_probabilities(self): exc_p1 = torch.clamp(self.exc_p1.data, min=0., max=1.) inh_p1 = torch.clamp(self.inh_p1.data, min=0., max=1.) return exc_p1, inh_p1 def weight_size(self): return self.exc_weight.size() def regularization(self): """ Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty """ if self.l0_strength > 0 or self.l2_strength > 0: # Clamp these, but do it in a way that still always propagates the # gradient. 
exc_p1 = self.exc_p1.clone() torch.clamp(exc_p1.data, min=0, max=1, out=exc_p1.data) inh_p1 = self.inh_p1.clone() torch.clamp(inh_p1.data, min=0, max=1, out=inh_p1.data) if self.l2_strength == 0: return self.l0_strength * (exc_p1 + inh_p1).sum() else: exc_weight_decay_ungated = (.5 * self.l2_strength * self.exc_weight.pow(2)) inh_weight_decay_ungated = (.5 * self.l2_strength * self.inh_weight.pow(2)) exc_weight_l2_l0 = torch.sum( (exc_weight_decay_ungated + self.l0_strength) * exc_p1) inh_weight_l2_l0 = torch.sum( (inh_weight_decay_ungated + self.l0_strength) * inh_p1) bias_l2 = (0 if not self.use_bias else torch.sum( .5 * self.l2_strength * self.bias.pow(2))) return exc_weight_l2_l0 + inh_weight_l2_l0 + bias_l2 else: return 0 def get_inference_mask(self): exc_p1, inh_p1 = self.get_gate_probabilities() if self.deterministic: exc_mask = (exc_p1 >= 0.5).float() inh_mask = (inh_p1 >= 0.5).float() return exc_mask, inh_mask else: exc_count1 = exc_p1.sum(dim=1).round().int() inh_count1 = inh_p1.sum(dim=1).round().int() # pytorch doesn't offer topk with varying k values. exc_mask = torch.zeros_like(exc_p1) inh_mask = torch.zeros_like(inh_p1) for i in range(exc_count1.size()[0]): _, exc_indices = torch.topk(exc_p1[i], exc_count1[i].item()) _, inh_indices = torch.topk(inh_p1[i], inh_count1[i].item()) exc_mask[i].scatter_(-1, exc_indices, 1) inh_mask[i].scatter_(-1, inh_indices, 1) return exc_mask, inh_mask def sample_weight_and_bias(self): if self.training or not self.optimize_inference: w = (sample_weight(self.exc_p1, self.exc_weight, self.deterministic) - sample_weight(self.inh_p1, self.inh_weight, self.deterministic)) else: exc_mask, inh_mask = self.get_inference_mask() w = exc_mask * self.exc_weight - inh_mask * self.inh_weight b = None if self.use_baseline_bias: b = -w.sum(dim=-1) / 2 if self.use_bias: b = (b + self.bias if b is not None else self.bias) return w, b def forward(self, x): if self.one_sample_per_item and self.training and len(x.size()) > 1: results = [] for i in range(x.size(0)): w, b = self.sample_weight_and_bias() results.append(F.linear(x[i:i + 1], w, b)) return torch.cat(results) else: w, b = self.sample_weight_and_bias() return F.linear(x, w, b) return self._forward(x) def get_expected_nonzeros(self): exc_p1, inh_p1 = self.get_gate_probabilities() # Flip two coins with probabilities pi_1 and pi_2. What is the # probability one of them is 1? # # 1 - (1 - pi_1)*(1 - pi_2) # = 1 - 1 + pi_1 + pi_2 - pi_1*pi_2 # = pi_1 + pi_2 - pi_1*pi_2 p1 = exc_p1 + inh_p1 - (exc_p1 * inh_p1) return p1.sum(dim=1).detach() def get_inference_nonzeros(self): exc_mask, inh_mask = self.get_inference_mask() return torch.sum(exc_mask.int() | inh_mask.int(), dim=1) def count_inference_flops(self): # For each unit, multiply with its n inputs then do n - 1 additions. # To capture the -1, subtract it, but only in cases where there is at # least one weight. nz_by_unit = self.get_inference_nonzeros() multiplies = torch.sum(nz_by_unit) adds = multiplies - torch.sum(nz_by_unit > 0) return multiplies.item(), adds.item()
class CLTLinear(nn.Module):

    def __init__(self, in_features, out_features, prior_prec=10, relu_act=True, elu_act=False):
        super(CLTLinear, self).__init__()
        self.n_in = in_features
        self.n_out = out_features
        self.prior_prec = prior_prec
        # A single layer can only do either relu or elu activation
        assert not (relu_act and elu_act)
        self.relu_act = relu_act
        self.elu_act = elu_act
        self.bias = nn.Parameter(th.Tensor(out_features))
        self.mu_w = Parameter(th.Tensor(out_features, in_features))
        self.logsig2_w = nn.Parameter(th.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        # TODO: Adapt to the newest pytorch initializations
        stdv = 1. / math.sqrt(self.mu_w.size(1))
        self.mu_w.data.normal_(0, stdv)
        self.logsig2_w.data.zero_().normal_(-9, 0.001)
        self.bias.data.zero_()

    def KL(self, loguniform=False):
        if loguniform:
            k1, k2, k3 = 0.63576, 1.87320, 1.48695
            log_alpha = self.logsig2_w - 2 * th.log(self.mu_w.abs() + 1e-8)
            kl = -th.sum(k1 * th.sigmoid(k2 + k3 * log_alpha)
                         - 0.5 * F.softplus(-log_alpha) - k1)
        else:
            logsig2_w = self.logsig2_w.clamp(-11, 11)
            kl = 0.5 * (self.prior_prec * (self.mu_w.pow(2) + logsig2_w.exp())
                        - logsig2_w - 1 - np.log(self.prior_prec)).sum()
        return kl

    def cdf(self, x, mu=0., sig=1.):
        return 0.5 * (1 + th.erf((x - mu) / (sig * math.sqrt(2))))

    def pdf(self, x, mu=0., sig=1.):
        return (1 / (math.sqrt(2 * math.pi) * sig)) * th.exp(-0.5 * ((x - mu) / sig).pow(2))

    def relu_moments(self, mu, sig):
        alpha = mu / sig
        cdf = self.cdf(alpha)
        pdf = self.pdf(alpha)
        relu_mean = mu * cdf + sig * pdf
        relu_var = (sig.pow(2) + mu.pow(2)) * cdf + mu * sig * pdf - relu_mean.pow(2)
        relu_var.clamp_(min=1e-8)  # Avoid negative variance due to numerics
        return relu_mean, relu_var

    def elu_moments_orig(self, mu, sig):
        # the original method without simplifications
        sig2 = sig.pow(2)
        elu_mean = th.exp(mu.clamp_max(10) + sig2 / 2) * self.cdf(-(mu + sig2) / sig) \
            - self.cdf(-mu / sig)
        elu_mean += mu * self.cdf(mu / sig) + sig * self.pdf(mu / sig)
        elu_var = th.exp(2 * mu.clamp_max(10) + 2 * sig2) * self.cdf(-(mu + 2 * sig2) / sig)
        elu_var += -2 * th.exp(mu.clamp_max(10) + sig2 / 2) * self.cdf(-(mu + sig2) / sig)
        elu_var += self.cdf(-mu / sig)
        elu_var += (sig2 + mu.pow(2)) * self.cdf(mu / sig) + mu * sig * self.pdf(mu / sig)
        elu_var += -elu_mean.pow(2)
        elu_var.clamp_min_(1e-8)  # Avoid negative variance due to numerics
        return elu_mean, elu_var

    def elu_moments(self, mu, sig):
        # NOTE: For now it is without alpha or the selu extension!
        # Note: takes roughly 3x as much time as the relu.
        # Clamp the mus to avoid problems in the expectation.
        sig2 = sig.pow(2)
        alpha = mu / sig
        cdf_alpha = self.cdf(alpha)
        pdf_alpha = self.pdf(alpha)
        cdf_malpha = 1 - cdf_alpha
        cdf_malphamsig = self.cdf(-alpha - sig)
        elu_mean = th.exp(mu.clamp_max(10) + sig2 / 2) * cdf_malphamsig - cdf_malpha
        elu_mean += mu * cdf_alpha + sig * pdf_alpha
        elu_var = th.exp(2 * mu.clamp_max(10) + 2 * sig2) * self.cdf(-alpha - 2 * sig)
        elu_var += -2 * th.exp(mu.clamp_max(10) + sig2 / 2) * cdf_malphamsig
        elu_var += cdf_malpha
        elu_var += (sig2 + mu.pow(2)) * cdf_alpha + mu * sig * pdf_alpha
        elu_var += -elu_mean.pow(2)
        elu_var.clamp_min_(1e-8)  # Avoid negative variance due to numerics
        return elu_mean, elu_var

    def forward(self, mu_inp, var_inp=None):
        s2_w = self.logsig2_w.exp()
        mu_out = F.linear(mu_inp, self.mu_w, self.bias)
        if var_inp is None:
            var_out = F.linear(mu_inp.pow(2), s2_w) + 1e-8
        else:
            var_out = F.linear(var_inp + mu_inp.pow(2), s2_w) \
                + F.linear(var_inp, self.mu_w.pow(2)) + 1e-8
        if self.relu_act:
            mu_out, var_out = self.relu_moments(mu_out, var_out.sqrt())
        if self.elu_act:
            mu_out, var_out = self.elu_moments(mu_out, var_out.sqrt())
        return mu_out, var_out  # + 1e-8 already provided in the moment computation

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.n_in) + ' -> ' \
            + str(self.n_out) \
            + f", activation={self.relu_act or self.elu_act}" \
            + f" ({'relu' if self.relu_act else ('elu' if self.elu_act else '')}))"
class BinaryGatedConv2d(Module): """ Convolutional layer with binary stochastic gates """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, learn_weight=True, bias=True, droprate_init=0.5, l2_strength=1., l0_strength=1., random_weight=True, deterministic=False, use_baseline_bias=False, optimize_inference=True, one_sample_per_item=False, **kwargs): """ :param in_channels: Number of input channels :param out_channels: Number of output channels :param kernel_size: Size of the kernel :param stride: Stride for the convolution :param padding: Padding for the convolution :param dilation: Dilation factor for the convolution :param groups: How many groups we will assume in the convolution :param bias: Whether we will use a bias :param droprate_init: Dropout rate that the gates will be initialized to :param l2_strength: Strength of the L2 penalty :param l0_strength: Strength of the L0 penalty """ super(BinaryGatedConv2d, self).__init__() if in_channels % groups != 0: raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: raise ValueError("out_channels must be divisible by groups") self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.l2_strength = l2_strength self.l0_strength = l0_strength self.droprate_init = droprate_init if droprate_init != 0. else 0.5 self.deterministic = deterministic self.use_baseline_bias = use_baseline_bias self.optimize_inference = optimize_inference self.one_sample_per_item = one_sample_per_item self.random_weight = random_weight if random_weight: exc_weight = torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) inh_weight = torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) else: exc_weight = torch.ones(out_channels, in_channels // groups, *self.kernel_size) inh_weight = torch.ones(out_channels, in_channels // groups, *self.kernel_size) if learn_weight: self.exc_weight = Parameter(exc_weight) self.inh_weight = Parameter(inh_weight) else: self.register_buffer("exc_weight", exc_weight) self.register_buffer("inh_weight", inh_weight) self.exc_p1 = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.inh_p1 = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.dim_z = out_channels self.input_shape = None self.use_bias = bias if bias: self.bias = Parameter(torch.Tensor(out_channels)) self.reset_parameters() def reset_parameters(self): if self.random_weight: init.kaiming_normal_(self.exc_weight, mode="fan_out") init.kaiming_normal_(self.inh_weight, mode="fan_out") self.exc_weight.data.abs_() self.inh_weight.data.abs_() self.exc_p1.data.normal_(1 - self.droprate_init, 1e-2) self.inh_p1.data.normal_(1 - self.droprate_init, 1e-2) if self.use_bias: self.bias.data.fill_(0) def constrain_parameters(self, **kwargs): self.exc_weight.data.clamp_(min=0.) self.inh_weight.data.clamp_(min=0.) def weight_size(self): return self.exc_weight.size() def regularization(self): """ Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty """ if self.l0_strength > 0 or self.l2_strength > 0: # Clamp these, but do it in a way that still always propagates the # gradient. 
exc_p1 = self.exc_p1.clone() torch.clamp(exc_p1.data, min=0, max=1, out=exc_p1.data) inh_p1 = self.inh_p1.clone() torch.clamp(inh_p1.data, min=0, max=1, out=inh_p1.data) if self.l2_strength == 0: return self.l0_strength * (exc_p1 + inh_p1).sum() else: exc_weight_decay_ungated = (.5 * self.l2_strength * self.exc_weight.pow(2)) inh_weight_decay_ungated = (.5 * self.l2_strength * self.inh_weight.pow(2)) exc_weight_l2_l0 = torch.sum( (exc_weight_decay_ungated + self.l0_strength) * exc_p1) inh_weight_l2_l0 = torch.sum( (inh_weight_decay_ungated + self.l0_strength) * inh_p1) bias_l2 = (0 if not self.use_bias else torch.sum( .5 * self.l2_strength * self.bias.pow(2))) return exc_weight_l2_l0 + inh_weight_l2_l0 + bias_l2 else: return 0 def get_gate_probabilities(self): exc_p1 = torch.clamp(self.exc_p1.data, min=0., max=1.) inh_p1 = torch.clamp(self.inh_p1.data, min=0., max=1.) return exc_p1, inh_p1 def get_inference_mask(self): exc_p1, inh_p1 = self.get_gate_probabilities() if self.deterministic: exc_mask = (exc_p1 >= 0.5).float() inh_mask = (inh_p1 >= 0.5).float() return exc_mask, inh_mask else: exc_count1 = exc_p1.sum( dim=tuple(range(1, len(exc_p1.shape)))).round().int() inh_count1 = inh_p1.sum( dim=tuple(range(1, len(inh_p1.shape)))).round().int() # pytorch doesn't offer topk with varying k values. exc_mask = torch.zeros_like(exc_p1) inh_mask = torch.zeros_like(inh_p1) for i in range(exc_count1.size()[0]): _, exc_indices = torch.topk(exc_p1[i].flatten(), exc_count1[i].item()) _, inh_indices = torch.topk(inh_p1[i].flatten(), inh_count1[i].item()) exc_mask[i].flatten().scatter_(-1, exc_indices, 1) inh_mask[i].flatten().scatter_(-1, inh_indices, 1) return exc_mask, inh_mask def sample_weight_and_bias(self, samples=1): if self.training or not self.optimize_inference: w = (sample_weight(self.exc_p1, self.exc_weight, self.deterministic, samples) - sample_weight(self.inh_p1, self.inh_weight, self.deterministic, samples)) else: exc_mask, inh_mask = self.get_inference_mask() w = exc_mask * self.exc_weight - inh_mask * self.inh_weight b = None if self.use_baseline_bias: b = -w.sum(dim=(-3, -2, -1)) / 2 if self.use_bias: b = (b + self.bias if b is not None else self.bias) return w, b def forward(self, x): if self.input_shape is None: self.input_shape = x.size() if self.one_sample_per_item and self.training and len(x.size()) > 3: w, b = self.sample_weight_and_bias(x.size(0)) if self.use_baseline_bias: b = b.view(x.size(0) * self.out_channels) else: b = b.repeat(x.size(0)) x_ = x.view(1, x.size(0) * x.size(1), *x.size()[2:]) w_ = w.view(w.size(0) * w.size(1), *w.size()[2:]) result = F.conv2d(x_, w_, b, self.stride, self.padding, self.dilation, x.size(0) * self.groups) return result.view(x.size(0), self.out_channels, *result.size()[2:]) else: w, b = self.sample_weight_and_bias() return F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups) def get_expected_nonzeros(self): exc_p1, inh_p1 = self.get_gate_probabilities() # Flip two coins with probabilities pi_1 and pi_2. What is the # probability one of them is 1? # # 1 - (1 - pi_1)*(1 - pi_2) # = 1 - 1 + pi_1 + pi_2 - pi_1*pi_2 # = pi_1 + pi_2 - pi_1*pi_2 p1 = exc_p1 + inh_p1 - (exc_p1 * inh_p1) return p1.sum(dim=tuple(range(1, len(p1.shape)))).detach() def get_inference_nonzeros(self): exc_mask, inh_mask = self.get_inference_mask() return torch.sum(exc_mask.int() | inh_mask.int(), dim=tuple(range(1, len(exc_mask.shape)))) def count_inference_flops(self): # For each unit, multiply with n inputs then do n - 1 additions. 
# Only subtract 1 in cases where is at least one weight. nz_by_unit = self.get_inference_nonzeros() multiplies_per_instance = torch.sum(nz_by_unit) adds_per_instance = multiplies_per_instance - torch.sum(nz_by_unit > 0) # for rows instances = ((self.input_shape[-2] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 # multiplying with cols instances *= ((self.input_shape[-1] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 multiplies = multiplies_per_instance * instances adds = adds_per_instance * instances return multiplies.item(), adds.item()
class _ConvNdGroupNJ(Module):
    """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout).

    References:
    [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015).
    [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017).
    [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation,
                 transposed, output_padding, groups, bias, init_weight, init_bias,
                 cuda=False, clip_var=None):
        super(_ConvNdGroupNJ, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.cuda = cuda
        self.clip_var = clip_var
        self.deterministic = False  # flag is used for compressed inference
        if transposed:
            self.weight_mu = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight_mu = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
            self.weight_logvar = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        self.bias_mu = Parameter(torch.Tensor(out_channels))
        self.bias_logvar = Parameter(torch.Tensor(out_channels))
        self.z_mu = Parameter(torch.Tensor(self.out_channels))
        self.z_logvar = Parameter(torch.Tensor(self.out_channels))
        self.reset_parameters(init_weight, init_bias)
        # activations for kl
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()
        # numerical stability param
        self.epsilon = 1e-8

    def reset_parameters(self, init_weight, init_bias):
        # init means
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        # init means
        if init_weight is not None:
            self.weight_mu.data = init_weight
        else:
            self.weight_mu.data.uniform_(-stdv, stdv)
        if init_bias is not None:
            self.bias_mu.data = init_bias
        else:
            self.bias_mu.data.fill_(0)
        # init z
        self.z_mu.data.normal_(1, 1e-2)
        # init logvars
        self.z_logvar.data.normal_(-9, 1e-2)
        self.weight_logvar.data.normal_(-9, 1e-2)
        self.bias_logvar.data.normal_(-9, 1e-2)

    def clip_variances(self):
        if self.clip_var:
            self.weight_logvar.data.clamp_(max=math.log(self.clip_var))
            self.bias_logvar.data.clamp_(max=math.log(self.clip_var))

    def get_log_dropout_rates(self):
        log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon)
        return log_alpha

    def compute_posterior_params(self):
        weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp()
        self.post_weight_var = self.z_mu.pow(2) * weight_var \
            + z_var * self.weight_mu.pow(2) + z_var * weight_var
        self.post_weight_mu = self.weight_mu * self.z_mu
        return self.post_weight_mu, self.post_weight_var

    def kl_divergence(self):
        # KL(q(z)||p(z))
        # we use the kl divergence approximation given by [2] Eq.(14)
        k1, k2, k3 = 0.63576, 1.87320, 1.48695
        log_alpha = self.get_log_dropout_rates()
        KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha)
                         - 0.5 * self.softplus(-log_alpha) - k1)
        # KL(q(w|z)||p(w|z))
        # we use the kl divergence given by [3] Eq.(8)
        KLD_element = -0.5 * self.weight_logvar \
            + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)
        # KL bias
        KLD_element = -0.5 * self.bias_logvar \
            + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5
        KLD += torch.sum(KLD_element)
        return KLD

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
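# A pruning sketch, following the Bayesian-compression convention of thresholding the
# per-group log dropout rate returned by get_log_dropout_rates(); the threshold value
# 3.0 is an assumption, not something fixed by the layer.
def group_keep_mask(layer, threshold=3.0):
    log_alpha = layer.get_log_dropout_rates()   # one value per output channel
    return log_alpha < threshold                # True = keep the channel, False = prune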
class group_relaxed_TF1Conv2d(Module): """Implementation of TF1 regularization for the feature maps of a convolutional layer""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, lamba=1., alpha=1., beta=4., weight_decay = 1., **kwargs): """ :param in_channels: Number of input channels :param out_channels: Number of output channels :param kernel_size: size of the kernel :param stride: stride for the convolution :param padding: padding for the convolution :param dilation: dilation factor for the convolution :param groups: how many groups we will assume in the convolution :param bias: whether we will use a bias :param lamba: strength of the TFL regularization """ super(group_relaxed_TF1Conv2d, self).__init__() self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.lamba = lamba self.alpha = alpha self.beta = beta self.lamba1 = self.lamba/self.beta self.weight_decay = weight_decay self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size) self.u = self.u.to('cuda') if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() self.input_shape = None print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode='fan_in') if self.bias is not None: self.bias.data.normal_(0,1e-2) def phi(self,x): phi_x = torch.acos(1-27*(self.lamba1*self.alpha*(self.alpha+1))/(2*(self.alpha+x.abs())**3)) return phi_x def g(self,x): g_x = x.sign()*(2/3*(self.alpha + x.abs())*torch.cos(self.phi(x)/3)-2*self.alpha/3+x.abs()/3) return g_x def constrain_parameters(self, thres_std=1.): #self.weight.data = F.normalize(self.weight.data, p=2, dim=1) #print(torch.sum(self.weight.pow(2))) if self.lamba1 <= (self.alpha**2)/(2*(self.alpha+1)): t = self.lamba1*(self.alpha+1)/(self.alpha) else: t = np.sqrt(2*self.lamba1*(self.alpha+1))-self.alpha/2 self.u.data = self.weight.data.clone() self.u.data[self.u.data.abs() <=t] = 0 g_result = self.g(self.u) self.u.data[self.u.data.abs() > t] = g_result[self.u.data.abs() > t] def grow_beta(self, growth_factor): self.beta = self.beta*growth_factor self.lamba1 = self.lamba/self.beta def _reg_w(self, **kwargs): logpw = -self.beta*torch.sum(0.5*self.weight.add(-self.u).pow(2))-self.lamba*np.sqrt(self.in_channels*self.kernel_size[0]*self.kernel_size[1])*torch.sum(torch.pow(torch.sum(self.weight.pow(2),3).sum(2).sum(1),0.5)) logpb = 0 if self.bias is not None: logpb = - torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw+logpb def regularization(self): return self._reg_w() def count_zero_u(self): total = np.prod(self.u.size()) zero = total - self.u.nonzero().size(0) return zero def count_zero_w(self): return torch.sum((self.weight.abs()<1e-5).int()).item() def count_active_neuron(self): return torch.sum((torch.sum(self.weight.abs(),3).sum(2).sum(1)/(self.in_channels*self.kernel_size[0]*self.kernel_size[1]))>1e-5).item() def count_total_neuron(self): return self.out_channels def count_weight(self): return np.prod(self.u.size()) def count_expected_flops_and_l0(self): #ppos = self.out_channels ppos = 
torch.sum(torch.sum(self.weight.abs(),3).sum(2).sum(1)>0.001).item() n = self.kernel_size[0]*self.kernel_size[1]*self.in_channels flops_per_instance = n+(n-1) num_instances_per_filter = ((self.input_shape[1] -self.kernel_size[0]+2*self.padding[0])/self.stride[0]) + 1 num_instances_per_filter *=((self.input_shape[2] - self.kernel_size[1]+2*self.padding[1])/self.stride[1]) + 1 flops_per_filter = num_instances_per_filter * flops_per_instance expected_flops = flops_per_filter*ppos expected_l0 = n*ppos if self.bias is not None: expected_flops += num_instances_per_filter*ppos expected_l0 += ppos return expected_flops, expected_l0 def forward(self, input_): if self.input_shape is None: self.input_shape = input_.size() output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return output def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} ' ', stride={stride}') if self.padding != (0,) * len(self.padding): s += ', padding={padding}' if self.dilation != (1,) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0,) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if self.bias is None: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
class HardConcreteGatedConv2d(Module): """ Convolutional layer with stochastic connections, as in https://arxiv.org/abs/1712.01312 """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, learn_weight=True, droprate_init=0.5, temperature=(2 / 3), l2_strength=1., l0_strength=1., **kwargs): """ :param in_channels: Number of input channels :param out_channels: Number of output channels :param kernel_size: Size of the kernel :param stride: Stride for the convolution :param padding: Padding for the convolution :param dilation: Dilation factor for the convolution :param groups: How many groups we will assume in the convolution :param bias: Whether we will use a bias :param droprate_init: Dropout rate that the L0 gates will be initialized to :param temperature: Temperature of the concrete distribution :param l2_strength: Strength of the L2 penalty :param l0_strength: Strength of the L0 penalty """ super(HardConcreteGatedConv2d, self).__init__() if in_channels % groups != 0: raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: raise ValueError("out_channels must be divisible by groups") self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.l2_strength = l2_strength self.l0_strength = l0_strength self.droprate_init = droprate_init if droprate_init != 0. else 0.5 self.temperature = temperature self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) self.use_bias = False weight = torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) if learn_weight: self.weight = Parameter(weight) else: self.register_buffer("weight", weight) self.loga = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.dim_z = out_channels self.input_shape = None if bias: bias = torch.Tensor(out_channels) if learn_weight: self.bias = Parameter(bias) else: self.register_buffer("bias", bias) self.use_bias = True self.reset_parameters() def reset_parameters(self): init.kaiming_normal_(self.weight, mode="fan_in") self.loga.data.normal_( math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2) if self.use_bias: self.bias.data.fill_(0) def constrain_parameters(self, **kwargs): self.loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2)) def cdf_qz(self, x): """Implements the CDF of the 'stretched' concrete distribution""" xn = (x - LIMIT_A) / (LIMIT_B - LIMIT_A) logits = math.log(xn) - math.log(1 - xn) return torch.sigmoid(logits * self.temperature - self.loga).clamp( min=EPSILON, max=1 - EPSILON) def quantile_concrete(self, x): """ Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution """ y = torch.sigmoid( (torch.log(x) - torch.log(1 - x) + self.loga) / self.temperature) return y * (LIMIT_B - LIMIT_A) + LIMIT_A def weight_size(self): return self.weight.size() def regularization(self): """ Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty """ weight_decay_ungated = .5 * self.l2_strength * self.weight.pow(2) weight_l2_l0 = torch.sum( (weight_decay_ungated + self.l0_strength) * (1 - self.cdf_qz(0))) bias_l2 = (0 if not self.use_bias else torch.sum(.5 * self.l2_strength * self.bias.pow(2))) return weight_l2_l0 + bias_l2 def count_inference_flops(self): # For each 
unit, multiply with n inputs then do n - 1 additions. # Only subtract 1 in cases where is at least one weight. nz_by_unit = self.get_inference_nonzeros() multiplies_per_instance = torch.sum(nz_by_unit) adds_per_instance = multiplies_per_instance - torch.sum(nz_by_unit > 0) # for rows instances = ((self.input_shape[-2] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 # multiplying with cols instances *= ((self.input_shape[-1] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 multiplies = multiplies_per_instance * instances adds = adds_per_instance * instances return multiplies.item(), adds.item() def count_expected_flops_and_l0(self): """ Measures the expected floating point operations (FLOPs) and the expected L0 norm Copied from the original L0 paper code """ ppos = torch.sum(1 - self.cdf_qz(0)) # vector_length n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels # (n: multiplications and n-1: additions) flops_per_instance = n + (n - 1) # for rows num_instances_per_filter = ( (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 # multiplying with cols num_instances_per_filter *= ( (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 flops_per_filter = num_instances_per_filter * flops_per_instance # multiply with number of filters expected_flops = flops_per_filter * ppos expected_l0 = n * ppos if self.use_bias: # since the gate is applied to the output we also reduce the bias # computation expected_flops += num_instances_per_filter * ppos expected_l0 += ppos return expected_flops.data[0], expected_l0.data[0] def get_eps(self, size): """Uniform random numbers for the concrete distribution""" eps = self.floatTensor(size).uniform_(EPSILON, 1 - EPSILON) eps = Variable(eps) return eps def sample_weight(self): if self.training: z = self.quantile_concrete( self.get_eps(self.floatTensor(self.loga.size()))) mask = F.hardtanh(z, min_val=0, max_val=1) else: pi = torch.sigmoid(self.loga) mask = F.hardtanh(pi * (LIMIT_B - LIMIT_A) + LIMIT_A, min_val=0, max_val=1) return mask * self.weight def forward(self, x): if self.input_shape is None: self.input_shape = x.size() return F.conv2d(x, self.sample_weight(), (self.bias if self.use_bias else None), self.stride, self.padding, self.dilation, self.groups) def get_expected_nonzeros(self): expected_gates = 1 - self.cdf_qz(0) return expected_gates.sum( dim=tuple(range(1, len(expected_gates.shape)))).detach() def get_inference_nonzeros(self): inference_gates = F.hardtanh(torch.sigmoid(self.loga) * (LIMIT_B - LIMIT_A) + LIMIT_A, min_val=0, max_val=1) return (inference_gates > 0).sum( dim=tuple(range(1, len(inference_gates.shape)))).detach() def __repr__(self): s = ( "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}, " "stride={stride}, droprate_init={droprate_init}, " "temperature={temperature}, l2_strength={l2_strength}, " "l0_strength={l0_strength}") if self.padding != (0, ) * len(self.padding): s += ", padding={padding}" if self.dilation != (1, ) * len(self.dilation): s += ", dilation={dilation}" if self.output_padding != (0, ) * len(self.output_padding): s += ", output_padding={output_padding}" if self.groups != 1: s += ", groups={groups}" if not self.use_bias: s += ", bias=False" s += ")" return s.format(name=self.__class__.__name__, **self.__dict__)
class BinaryGatedConv2d(Module): """ Convolutional layer with binary stochastic gates """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, learn_weight=True, bias=True, droprate_init=0.5, l2_strength=1., l0_strength=1., **kwargs): """ :param in_channels: Number of input channels :param out_channels: Number of output channels :param kernel_size: Size of the kernel :param stride: Stride for the convolution :param padding: Padding for the convolution :param dilation: Dilation factor for the convolution :param groups: How many groups we will assume in the convolution :param bias: Whether we will use a bias :param droprate_init: Dropout rate that the gates will be initialized to :param l2_strength: Strength of the L2 penalty :param l0_strength: Strength of the L0 penalty """ super(BinaryGatedConv2d, self).__init__() if in_channels % groups != 0: raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: raise ValueError("out_channels must be divisible by groups") self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.l2_strength = l2_strength self.l0_strength = l0_strength self.droprate_init = droprate_init if droprate_init != 0. else 0.5 self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) self.use_bias = False weight = torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) if learn_weight: self.weight = Parameter(weight) else: self.register_buffer("weight", weight) self.logit_p1 = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.dim_z = out_channels self.input_shape = None if bias: b = torch.Tensor(out_channels) if learn_weight: self.bias = Parameter(b) else: self.register_buffer("bias", b) self.use_bias = True self.reset_parameters() def reset_parameters(self): init.kaiming_normal_(self.weight, mode="fan_in") self.logit_p1.data.normal_( math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2) if self.use_bias: self.bias.data.fill_(0) def constrain_parameters(self, **kwargs): pass def regularization(self): """ Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty """ p1 = torch.sigmoid(self.logit_p1) weight_decay_ungated = .5 * self.l2_strength * self.weight.pow(2) weight_l2_l0 = torch.sum( (weight_decay_ungated + self.l0_strength) * p1) bias_l2 = (0 if not self.use_bias else torch.sum(.5 * self.l2_strength * self.bias.pow(2))) return -weight_l2_l0 - bias_l2 def sample_weight(self): u = self.floatTensor(self.logit_p1.size()).uniform_(0, 1) p1 = torch.sigmoid(self.logit_p1) mask = p1 > u def cc_to_p1(grad): ratio = p1 / (1 - p1) p1.backward(grad * torch.where(mask, 1 / ratio, ratio)) return grad z = mask.float() z.requires_grad_() z.register_hook(cc_to_p1) return self.weight * z def forward(self, x): return F.conv2d(x, self.sample_weight(), (self.bias if self.use_bias else None), self.stride, self.padding, self.dilation, self.groups) def get_expected_nonzeros(self): expected_gates = torch.sigmoid(self.logit_p1) return expected_gates.sum( dim=tuple(range(1, len(expected_gates.shape)))).detach() def get_inference_nonzeros(self): u = self.floatTensor(self.logit_p1.size()).uniform_(0, 1) inference_gates = torch.sigmoid(self.logit_p1) > u return inference_gates.sum( 
dim=tuple(range(1, len(inference_gates.shape)))).detach()
class group_relaxed_L0Dense(Module): """Implementation of TFL regularization for the input units of a fully connected layer""" def __init__(self, in_features, out_features, bias=True, lamba=1., beta=4., weight_decay=1., **kwargs): """ :param in_features: input dimensionality :param out_features: output dimensionality :param bias: whether we use bias :param lamba: strength of the TF1 regularization """ super(group_relaxed_L0Dense, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(in_features, out_features)) self.u = torch.rand(in_features, out_features) self.u = self.u.to('cuda') if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.lamba = lamba self.beta = beta self.weight_decay = weight_decay self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode='fan_out') if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, **kwargs): #self.weight.data = F.normalize(self.weight.data, p=2, dim=1) m = Hardshrink((2 * self.lamba / self.beta)**(1 / 2)) self.u.data = m(self.weight.data) def grow_beta(self, growth_factor): self.beta = self.beta * growth_factor def _reg_w(self, **kwargs): logpw = -self.beta * torch.sum( 0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt( self.out_features) * torch.sum( torch.pow(torch.sum(self.weight.pow(2), 1), 0.5)) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_zero_u(self): total = np.prod(self.u.size()) zero = total - self.u.nonzero().size(0) return zero def count_zero_w(self): return torch.sum((self.weight.abs() < 1e-5).int()).item() def count_weight(self): return np.prod(self.u.size()) def count_active_neuron(self): return torch.sum( torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item() def count_total_neuron(self): return self.in_features def count_expected_flops_and_l0(self): ppos = torch.sum(self.weight.abs() > 0.000001).item() expected_flops = (2 * ppos - 1) * self.out_features expected_l0 = ppos * self.out_features if self.bias is not None: expected_flops += self.out_features expected_l0 += self.out_features return expected_flops, expected_l0 def forward(self, input): output = input.mm(self.weight) if self.bias is not None: output.add_(self.bias.view(1, self.out_features).expand_as(output)) return output def __repr__(self): return self.__class__.__name__+' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', lambda: ' \ + str(self.lamba) + ')'
class L0Dense(Module):
    """Implementation of L0 regularization for the input units of a fully connected layer"""

    def __init__(self, in_features, out_features, bias=True, weight_decay=1.,
                 droprate_init=0.5, temperature=2. / 3., lamba=1., local_rep=False, **kwargs):
        """
        :param in_features: Input dimensionality
        :param out_features: Output dimensionality
        :param bias: Whether we use a bias
        :param weight_decay: Strength of the L2 penalty
        :param droprate_init: Dropout rate that the L0 gates will be initialized to
        :param temperature: Temperature of the concrete distribution
        :param lamba: Strength of the L0 penalty
        :param local_rep: Whether we will use a separate gate sample per element in the minibatch
        """
        super(L0Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_prec = weight_decay
        self.weights = Parameter(torch.Tensor(in_features, out_features))
        self.qz_loga = Parameter(torch.Tensor(in_features))
        self.temperature = temperature
        self.droprate_init = droprate_init if droprate_init != 0. else 0.5
        self.lamba = lamba
        self.use_bias = False
        self.local_rep = local_rep
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weights, mode='fan_out')
        self.qz_loga.data.normal_(
            math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2)
        if self.use_bias:
            self.bias.data.fill_(0)

    def constrain_parameters(self, **kwargs):
        self.qz_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2))

    def cdf_qz(self, x):
        """Implements the CDF of the 'stretched' concrete distribution"""
        xn = (x - limit_a) / (limit_b - limit_a)
        logits = math.log(xn) - math.log(1 - xn)
        return torch.sigmoid(logits * self.temperature - self.qz_loga).clamp(
            min=epsilon, max=1 - epsilon)

    def quantile_concrete(self, x):
        """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
        y = torch.sigmoid((torch.log(x) - torch.log(1 - x) + self.qz_loga) / self.temperature)
        return y * (limit_b - limit_a) + limit_a

    def _reg_w(self):
        """Expected L0 norm under the stochastic gates, takes into account
        and re-weights also a potential L2 penalty"""
        logpw_col = torch.sum(-(.5 * self.prior_prec * self.weights.pow(2)) - self.lamba, 1)
        logpw = torch.sum((1 - self.cdf_qz(0)) * logpw_col)
        logpb = 0 if not self.use_bias else -torch.sum(.5 * self.prior_prec * self.bias.pow(2))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_expected_flops_and_l0(self):
        """Measures the expected floating point operations (FLOPs) and the expected L0 norm"""
        # dim_in multiplications and dim_in - 1 additions for each output neuron for the weights
        # + the bias addition for each neuron
        # total_flops = (2 * in_features - 1) * out_features + out_features
        ppos = torch.sum(1 - self.cdf_qz(0))
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.use_bias:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops.item(), expected_l0.item()

    def get_eps(self, size):
        """Uniform random numbers for the concrete distribution"""
        eps = self.floatTensor(size).uniform_(epsilon, 1 - epsilon)
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size, sample=True):
        """Sample the hard-concrete gates for training and use a deterministic value for testing"""
        if sample:
            eps = self.get_eps(self.floatTensor(batch_size, self.in_features))
            z = self.quantile_concrete(eps)
            return F.hardtanh(z, min_val=0, max_val=1)
        else:  # mode
            pi = torch.sigmoid(self.qz_loga).view(1, self.in_features).expand(
                batch_size, self.in_features)
            return F.hardtanh(pi * (limit_b - limit_a) + limit_a, min_val=0, max_val=1)

    def sample_weights(self):
        z = self.quantile_concrete(self.get_eps(self.floatTensor(self.in_features)))
        mask = F.hardtanh(z, min_val=0, max_val=1)
        return mask.view(self.in_features, 1) * self.weights

    def forward(self, input):
        if self.local_rep or not self.training:
            z = self.sample_z(input.size(0), sample=self.training)
            xin = input.mul(z)
            output = xin.mm(self.weights)
        else:
            weights = self.sample_weights()
            output = input.mm(weights)
        if self.use_bias:
            output.add_(self.bias)
        return output

    def __repr__(self):
        s = ('{name}({in_features} -> {out_features}, droprate_init={droprate_init}, '
             'lamba={lamba}, temperature={temperature}, weight_decay={prior_prec}, '
             'local_rep={local_rep}')
        if not self.use_bias:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
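# A usage sketch, assuming a network whose dense layers are L0Dense. `regularization()`
# is a log-prior (larger is better), so it is subtracted from the loss; the 1/N scaling
# mirrors the convention of the original L0-regularization code and is an assumption here.
def l0_objective(model, logits, target, n_train):
    loss = F.cross_entropy(logits, target)
    reg = sum(m.regularization() for m in model.modules() if isinstance(m, L0Dense))
    flops = sum(m.count_expected_flops_and_l0()[0]
                for m in model.modules() if isinstance(m, L0Dense))
    return loss - reg / n_train, flops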
class CLTLayer(nn.Module):

    def __init__(self, in_features, out_features, alpha=10, isinput=False, isoutput=False):
        super(CLTLayer, self).__init__()
        self.n_in = in_features
        self.n_out = out_features
        self.isoutput = isoutput
        self.isinput = isinput
        self.alpha = alpha
        self.Mbias = nn.Parameter(torch.Tensor(out_features))
        self.M = Parameter(torch.Tensor(out_features, in_features))
        self.logS = nn.Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.M.size(1))
        self.M.data.normal_(0, stdv)
        self.logS.data.zero_().normal_(-9, 0.001)
        self.Mbias.data.zero_()

    def KL(self):
        logS = self.logS.clamp(-11, 11)
        kl = 0.5 * (self.alpha * (self.M.pow(2) + logS.exp()) - logS).sum()
        return kl

    def cdf(self, x, mu=0., sig=1.):
        return 0.5 * (1 + torch.erf((x - mu) / (sig * math.sqrt(2))))

    def pdf(self, x, mu=0., sig=1.):
        return (1 / (math.sqrt(2 * math.pi) * sig)) * torch.exp(-0.5 * ((x - mu) / sig).pow(2))

    def relu_moments(self, mu, sig):
        alpha = mu / sig
        cdf = self.cdf(alpha)
        pdf = self.pdf(alpha)
        relu_mean = mu * cdf + sig * pdf
        relu_var = (sig.pow(2) + mu.pow(2)) * cdf + mu * sig * pdf - relu_mean.pow(2)
        return relu_mean, relu_var

    def forward(self, mu_h, var_h):
        M = self.M
        var_s = self.logS.clamp(-11, 11).exp()
        mu_f = F.linear(mu_h, M, self.Mbias)
        # No input variance
        if self.isinput:
            var_f = F.linear(mu_h**2, var_s)
        else:
            var_f = F.linear(var_h + mu_h.pow(2), var_s) + F.linear(var_h, M.pow(2))
        # compute relu moments if it is not an output layer
        if not self.isoutput:
            return self.relu_moments(mu_f, var_f.sqrt())
        else:
            return mu_f, var_f

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.n_in) + ' -> ' \
            + str(self.n_out) \
            + f', isinput={self.isinput}, isoutput={self.isoutput})'
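# A stacking sketch showing the intended moment propagation: the first layer is flagged
# `isinput` (it ignores the incoming variance), the last `isoutput` (it returns raw
# moments instead of ReLU moments). Layer sizes below are illustrative assumptions.
class CLTMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = CLTLayer(784, 256, isinput=True)
        self.l2 = CLTLayer(256, 256)
        self.l3 = CLTLayer(256, 10, isoutput=True)

    def forward(self, x):
        mu, var = self.l1(x, None)    # input layer ignores the variance argument
        mu, var = self.l2(mu, var)
        return self.l3(mu, var)

    def kl(self):
        return self.l1.KL() + self.l2.KL() + self.l3.KL()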
class TDConv2d(Module): """Implementation of L0 regularization for the feature maps of a convolutional layer""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, dropout=0.5, dropout_botk=0.5, dropout_type="weight", temperature=2.0 / 3.0, weight_decay=1.0, lamba=1.0, local_rep=False, **kwargs): """ :param in_channels: Number of input channels :param out_channels: Number of output channels :param kernel_size: Size of the kernel :param stride: Stride for the convolution :param padding: Padding for the convolution :param dilation: Dilation factor for the convolution :param groups: How many groups we will assume in the convolution :param bias: Whether we will use a bias :param weight_decay: Strength of the L2 penalty """ super(TDConv2d, self).__init__() if in_channels % groups != 0: raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: raise ValueError("out_channels must be divisible by groups") self.weight_decay = weight_decay self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) self.prune_rate = 0 self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.weight = Parameter( self.floatTensor(out_channels, in_channels // groups, *self.kernel_size)) if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter("bias", None) self.dropout = dropout self.dropout_type = dropout_type self.dropout_botk = dropout_botk self.reset_parameters() self.input_shape = None print(self) print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode="fan_in") if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, thres_std=1.0): pass def _reg_w(self, **kwargs): logpw = -torch.sum(self.weight_decay * 0.5 * (self.weight.pow(2))) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * 0.5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_expected_flops_and_l0(self): ppos = self.out_channels n = self.kernel_size[0] * self.kernel_size[ 1] * self.in_channels # vector_length flops_per_instance = n + (n - 1 ) # (n: multiplications and n-1: additions) num_instances_per_filter = ( (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 # for rows num_instances_per_filter *= ( (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 # multiplying with cols flops_per_filter = num_instances_per_filter * flops_per_instance expected_flops = flops_per_filter * ppos # multiply with number of filters expected_l0 = n * ppos if self.bias is not None: # since the gate is applied to the output we also reduce the bias computation expected_flops += num_instances_per_filter * ppos expected_l0 += ppos return expected_flops, expected_l0 def targeted_dropout(self, w): drop_rate = self.dropout targ_perc = self.dropout_botk # print("w_orig: ", w) if self.dropout == 0: return w cuda0 = torch.device("cuda:0") if self.dropout_type == "weight": w_shape = w.size() w = w.view(w_shape[0], -1) norm = w.abs() idx = int(targ_perc * float(w.size()[1])) norm_sorted, _ = norm.sort(dim=1) threshold = norm_sorted[:, idx] mask = norm < threshold[:, None] if not self.training: w = (1.0 - mask.float()) * w w = w.view(w_shape) 
return w dropout_mask = torch.rand(w.size(), device=cuda0) < drop_rate mask = dropout_mask & mask w = (1.0 - mask.float()) * w w = w.view(w_shape) return w if self.dropout_type == "unit": w_shape = w.size() w = w.view(w_shape[0], -1) idx = int(targ_perc * float(w.size()[0])) norm = w.norm(p=2, dim=1) norm_sorted, _ = norm.sort(dim=0) #print("norm_sorted:", norm_sorted) threshold = norm_sorted[idx] #print("thresh:", threshold) mask = norm < threshold #print("mask:", mask, mask.size()) mask = mask.repeat(1, w.size()[1]).view(w.size()[0], -1) #print(mask.size(), w.size(), "yolo") dropout_mask = torch.rand(w.size(), device=cuda0) < drop_rate mask = dropout_mask & mask w = (1.0 - mask.float()) * w w = w.view(w_shape) return w def prune(self, botk): self.prune_rate = botk def prune_weights(self, w): w_shape = w.size() w = w.view(-1, w_shape[-1]) norm = w.abs() idx = int(self.prune_rate * float(w.size()[0])) norm_sorted, _ = norm.sort(dim=0) threshold = norm_sorted[idx:idx + 1] mask = norm >= threshold w = mask.float() * w w = w.view(w_shape) return torch.nn.Parameter(w) def forward(self, input_): if self.input_shape is None: self.input_shape = input_.size() weight = self.targeted_dropout(self.weight) if self.prune_rate > 0.0: weight = self.prune_weights(weight) output = F.conv2d(input_, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return output def __repr__(self): s = ( "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}, " "dropout={dropout}, dropout_botk={dropout_botk}, ") if self.padding != (0, ) * len(self.padding): s += ", padding={padding}" if self.dilation != (1, ) * len(self.dilation): s += ", dilation={dilation}" if self.output_padding != (0, ) * len(self.output_padding): s += ", output_padding={output_padding}" if self.groups != 1: s += ", groups={groups}" s += ")" return s.format(name=self.__class__.__name__, **self.__dict__)
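For reference, a self-contained sketch (a hypothetical helper, not from the original code) of the "weight" variant of targeted dropout that TDConv2d.targeted_dropout implements: the bottom-k weights of each filter by magnitude are dropped with probability drop_rate during training, and removed deterministically at evaluation time.

import torch

def targeted_dropout_sketch(w, drop_rate=0.5, botk=0.5, training=True):
    # w: (out_channels, in_channels, kH, kW) weight tensor
    w_shape = w.shape
    w = w.reshape(w_shape[0], -1)
    idx = int(botk * w.size(1))
    threshold = w.abs().sort(dim=1).values[:, idx]      # per-filter magnitude threshold
    candidates = w.abs() < threshold[:, None]            # bottom-k weights are drop candidates
    if training:
        mask = candidates & (torch.rand_like(w) < drop_rate)   # drop candidates stochastically
    else:
        mask = candidates                                        # drop all candidates at eval time
    return ((~mask).float() * w).reshape(w_shape)

w = torch.randn(8, 3, 3, 3)
print(targeted_dropout_sketch(w, training=False).reshape(8, -1).eq(0).float().mean(dim=1))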
class sparse_group_lasso_Conv2d(Module): """Implementation of TF1 regularization for the feature maps of a convolutional layer""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, lamba=1., weight_decay=1., **kwargs): """ :param in_channels: Number of input channels :param out_channels: Number of output channels :param kernel_size: size of the kernel :param stride: stride for the convolution :param padding: padding for the convolution :param dilation: dilation factor for the convolution :param groups: how many groups we will assume in the convolution :param bias: whether we will use a bias :param lamba: strength of the TFL regularization """ super(sparse_group_lasso_Conv2d, self).__init__() self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.lamba = lamba self.weight_decay = weight_decay self.weight = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() self.input_shape = None print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode='fan_in') if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, thres_std=1.): pass def _reg_w(self, **kwargs): logpw = -self.lamba * np.sqrt( self.in_channels * self.kernel_size[0] * self.kernel_size[1]) * torch.sum( torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5)) - torch.sum(self.lamba * self.weight.abs()) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_zero_w(self): return torch.sum((self.weight.abs() < 1e-5).int()).item() def count_weight(self): return np.prod(self.weight.size()) def count_active_neuron(self): return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) / (self.in_channels * self.kernel_size[0] * self.kernel_size[1])) > 1e-5).item() def count_total_neuron(self): return self.out_channels def count_expected_flops_and_l0(self): #ppos = self.out_channels ppos = torch.sum( torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item() n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels flops_per_instance = n + (n - 1) num_instances_per_filter = ( (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 num_instances_per_filter *= ( (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 flops_per_filter = num_instances_per_filter * flops_per_instance expected_flops = flops_per_filter * ppos expected_l0 = n * ppos if self.bias is not None: expected_flops += num_instances_per_filter * ppos expected_l0 += ppos return expected_flops, expected_l0 def forward(self, input_): if self.input_shape is None: self.input_shape = input_.size() output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return output def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} ' ', stride={stride}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * 
len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if self.bias is None: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
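The quantity returned by _reg_w above is the negative of a sparse group lasso penalty over output filters: lamba * sqrt(in_channels * kH * kW) times the sum of per-filter L2 norms, plus lamba times the elementwise L1 norm. A hedged standalone version (assuming groups=1) that can be checked against -layer.regularization():

import math
import torch

def sparse_group_lasso_penalty(weight, lamba=1.0):
    # weight: (out_channels, in_channels, kH, kW); returns the positive penalty value
    group_size = weight[0].numel()                          # in_channels * kH * kW per filter
    group_norms = weight.pow(2).sum(dim=(1, 2, 3)).sqrt()   # L2 norm of each output filter
    return lamba * math.sqrt(group_size) * group_norms.sum() + lamba * weight.abs().sum()

print(sparse_group_lasso_penalty(torch.randn(16, 3, 3, 3), lamba=0.1))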
class _ConvNdGroupNJ(BayesianLayers): """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout). References: [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, init_weight, init_bias, cuda=False, clip_var=None): super(_ConvNdGroupNJ, self).__init__() if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: raise ValueError('out_channels must be divisible by groups') self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.dilation = dilation self.transposed = transposed self.output_padding = output_padding self.groups = groups self.cuda = cuda self.clip_var = clip_var self.deterministic = False # flag is used for compressed inference if transposed: self.weight_mu = Parameter(torch.Tensor( in_channels, out_channels // groups, *kernel_size)) self.weight_logvar = Parameter(torch.Tensor( in_channels, out_channels // groups, *kernel_size)) else: self.weight_mu = Parameter(torch.Tensor( out_channels, in_channels // groups, *kernel_size)) self.weight_logvar = Parameter(torch.Tensor( out_channels, in_channels // groups, *kernel_size)) self.bias_mu = Parameter(torch.Tensor(out_channels)) self.bias_logvar = Parameter(torch.Tensor(out_channels)) self.z_mu = Parameter(torch.Tensor(self.out_channels)) self.z_logvar = Parameter(torch.Tensor(self.out_channels)) self.reset_parameters(init_weight, init_bias) # activations for kl self.sigmoid = nn.Sigmoid() self.softplus = nn.Softplus() # numerical stability param self.epsilon = 1e-8 def reset_parameters(self, init_weight, init_bias): # init means n = self.in_channels for k in self.kernel_size: n *= k stdv = 1. / math.sqrt(n) if init_weight is not None: self.weight_mu.data = init_weight else: self.weight_mu.data.uniform_(-stdv, stdv) if init_bias is not None: self.bias_mu.data = init_bias else: self.bias_mu.data.fill_(0) # init z self.z_mu.data.normal_(1, 1e-2) # init logvars self.z_logvar.data.normal_(-9, 1e-2) self.weight_logvar.data.normal_(-9, 1e-2) self.bias_logvar.data.normal_(-9, 1e-2) def clip_variances(self): if self.clip_var: self.weight_logvar.data.clamp_(max=math.log(self.clip_var)) self.bias_logvar.data.clamp_(max=math.log(self.clip_var)) def get_log_dropout_rates(self): log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) return log_alpha def compute_posterior_params(self): weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() part1 = self.z_mu.pow(2) * weight_var part2 = z_var * self.weight_mu.pow(2) part3 = z_var * weight_var self.post_weight_var = part1 + part2 + part3 self.post_weight_mu = self.weight_mu * self.z_mu return self.post_weight_mu, self.post_weight_var def kl_divergence(self): # KL(q(z)||p(z)) # we use the kl divergence approximation given by [2] Eq.(14) k1, k2, k3 = 0.63576, 1.87320, 1.48695 log_alpha = self.get_log_dropout_rates() KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) # KL(q(w|z)||p(w|z)) # we use the kl divergence given by [3] Eq.(8) KLD_element = -self.weight_logvar + 0.5 * (self.weight_logvar.exp().pow(2) + self.weight_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) # KL bias KLD_element = -self.bias_logvar + 0.5 * (self.bias_logvar.exp().pow(2) + self.bias_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) return KLD def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}') if self.padding != (0,) * len(self.padding): s += ', padding={padding}' if self.dilation != (1,) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0,) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
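A small sketch of how the group Normal-Jeffreys layers are typically pruned after training: groups whose log dropout rate exceeds a threshold are removed. The threshold of 3 (roughly a 0.95 dropout rate) follows the rule commonly used with variational dropout; treat both the helper name and the threshold as assumptions.

def group_keep_mask(layer, log_alpha_threshold=3.0):
    # Keep the groups (channels/units) whose log dropout rate stays below the threshold.
    log_alpha = layer.get_log_dropout_rates()   # log alpha = z_logvar - log(z_mu^2 + eps)
    return log_alpha < log_alpha_threshold

# Hypothetical usage on a concrete subclass of _ConvNdGroupNJ:
# mask = group_keep_mask(conv_layer)
# kept_channels = int(mask.sum())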
class group_relaxed_SCAD_Dense(Module): """Implementation of group relaxed SCAD regularization for the input units of a fully connected layer""" def __init__(self, in_features, out_features, bias=True, lamba=1., alpha = 3.7, beta = 4.0, weight_decay=1., **kwargs): """ :param in_features: input dimensionality :param out_features: output dimensionality :param bias: whether we use bias :param lamba: strength of the SCAD regularization :param alpha: SCAD shape parameter a (3.7 is the value recommended by Fan & Li) :param beta: strength of the coupling between the weights and the auxiliary variable u """ super(group_relaxed_SCAD_Dense,self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(in_features, out_features)) self.u = torch.rand(in_features, out_features) if torch.cuda.is_available(): self.u = self.u.to('cuda') if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.lamba = lamba self.alpha = alpha self.beta = beta self.lamba1 = self.lamba/self.beta self.weight_decay = weight_decay self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.weight, mode='fan_out') if self.bias is not None: self.bias.data.normal_(0,1e-2) def constrain_parameters(self, **kwargs): self.u = self.weight.clone() s = Softshrink(self.lamba1) #soft-threshold values with absolute value at most 2*lamba1 shrink_value = s(self.weight.data) self.u[self.weight.abs()<=2*self.lamba1] = shrink_value[self.weight.abs()<=2*self.lamba1] #modify values whose absolute values are between 2*lamba1 and alpha*lamba1 modify_weight = self.weight.data modify_weight = ((self.alpha - 1)*modify_weight-modify_weight.sign()*(self.alpha*self.lamba1))/(self.alpha -2) self.u[(self.weight.abs()>2*self.lamba1) & (self.weight.abs()<=self.alpha*self.lamba1)] = modify_weight[(self.weight.abs()>2*self.lamba1) & (self.weight.abs()<=self.alpha*self.lamba1)] def grow_beta(self, growth_factor): self.beta = self.beta*growth_factor self.lamba1 = self.lamba/self.beta def _reg_w(self, **kwargs): logpw = -self.beta*torch.sum(0.5*self.weight.add(-self.u).pow(2))-self.lamba*np.sqrt(self.out_features)*torch.sum(torch.pow(torch.sum(self.weight.pow(2),1),0.5)) logpb = 0 if self.bias is not None: logpb = - torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_zero_u(self): total = np.prod(self.u.size()) zero = total - self.u.nonzero().size(0) return zero def count_zero_w(self): return torch.sum((self.weight.abs()<1e-5).int()).item() def count_weight(self): return np.prod(self.u.size()) def count_active_neuron(self): return torch.sum(torch.sum(self.weight.abs()/self.out_features,1)>1e-5).item() def count_total_neuron(self): return self.in_features def count_expected_flops_and_l0(self): ppos = torch.sum(self.weight.abs()>0.000001).item() expected_flops = (2*ppos-1)*self.out_features expected_l0 = ppos*self.out_features if self.bias is not None: expected_flops += self.out_features expected_l0 += self.out_features return expected_flops, expected_l0 def forward(self, input): output = input.mm(self.weight) if self.bias is not None: output.add_(self.bias.view(1, self.out_features).expand_as(output)) return output def __repr__(self): return self.__class__.__name__+' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', lambda: ' \ + str(self.lamba) + ')'
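For reference, the scalar SCAD thresholding rule that constrain_parameters applies elementwise, written as a hedged standalone function (with a = alpha and lam = lamba1): soft-thresholding below 2*lam, the interpolating formula between 2*lam and a*lam, and the identity above a*lam (Fan & Li, 2001).

import torch

def scad_threshold(w, lam, a=3.7):
    absw = w.abs()
    soft = torch.sign(w) * torch.clamp(absw - lam, min=0.0)        # |w| <= 2*lam: soft-threshold
    mid = ((a - 1.0) * w - torch.sign(w) * a * lam) / (a - 2.0)     # 2*lam < |w| <= a*lam
    out = torch.where(absw <= 2.0 * lam, soft, mid)
    return torch.where(absw > a * lam, w, out)                      # |w| > a*lam: leave unchanged

print(scad_threshold(torch.tensor([-3.0, -0.5, 0.2, 1.5, 4.0]), lam=0.5))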
class MAPConv2d(Module): def __init__( self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, weight_decay=1.0, **kwargs ): super(MAPConv2d, self).__init__() self.weight_decay = weight_decay self.floatTensor = ( torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor ) self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.weight = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) ) if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter("bias", None) self.reset_parameters() self.input_shape = None print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode="fan_in") if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, thres_std=1.0): pass def _reg_w(self, **kwargs): logpw = -torch.sum(self.weight_decay * 0.5 * (self.weight.pow(2))) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * 0.5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_expected_flops_and_l0(self): ppos = self.out_channels n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels # vector_length flops_per_instance = n + (n - 1) # (n: multiplications and n-1: additions) num_instances_per_filter = ( (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] ) + 1 # for rows num_instances_per_filter *= ( (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1] ) + 1 # multiplying with cols flops_per_filter = num_instances_per_filter * flops_per_instance expected_flops = flops_per_filter * ppos # multiply with number of filters expected_l0 = n * ppos if self.bias is not None: # since the gate is applied to the output we also reduce the bias computation expected_flops += num_instances_per_filter * ppos expected_l0 += ppos return expected_flops, expected_l0 def forward(self, input_): if self.input_shape is None: self.input_shape = input_.size() output = F.conv2d( input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) return output def __repr__(self): s = ( "{name}({in_channels}, {out_channels}, kernel_size={kernel_size} " ", stride={stride}, weight_decay={weight_decay}" ) if self.padding != (0,) * len(self.padding): s += ", padding={padding}" if self.dilation != (1,) * len(self.dilation): s += ", dilation={dilation}" if self.output_padding != (0,) * len(self.output_padding): s += ", output_padding={output_padding}" if self.groups != 1: s += ", groups={groups}" if self.bias is None: s += ", bias=False" s += ")" return s.format(name=self.__class__.__name__, **self.__dict__)
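A hedged usage sketch for MAPConv2d: regularization() returns a log-prior (the negative weight-decay penalty), so it is subtracted from the task loss. Sizes, learning rate and data are illustrative.

import torch
import torch.nn.functional as F

conv = MAPConv2d(3, 16, kernel_size=3, padding=1, weight_decay=1e-4)
opt = torch.optim.SGD(conv.parameters(), lr=0.1)

x = torch.randn(8, 3, 32, 32)
y = torch.randn(8, 16, 32, 32)

loss = F.mse_loss(conv(x), y) - conv.regularization()   # subtracting the log-prior adds the L2 penalty
opt.zero_grad()
loss.backward()
opt.step()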
class L0Conv2d(Module): """Implementation of L0 regularization for the feature maps of a convolutional layer""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, droprate_init=0.5, temperature=2. / 3., weight_decay=1., lamba=1., local_rep=False, **kwargs): """ :param in_channels: Number of input channels :param out_channels: Number of output channels :param kernel_size: Size of the kernel :param stride: Stride for the convolution :param padding: Padding for the convolution :param dilation: Dilation factor for the convolution :param groups: How many groups we will assume in the convolution :param bias: Whether we will use a bias :param droprate_init: Dropout rate that the L0 gates will be initialized to :param temperature: Temperature of the concrete distribution :param weight_decay: Strength of the L2 penalty :param lamba: Strength of the L0 penalty :param local_rep: Whether we will use a separate gate sample per element in the minibatch """ super(L0Conv2d, self).__init__() if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: raise ValueError('out_channels must be divisible by groups') self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.prior_prec = weight_decay self.lamba = lamba self.droprate_init = droprate_init if droprate_init != 0. else 0.5 self.temperature = temperature self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.use_bias = False self.weights = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.qz_loga = Parameter(torch.Tensor(out_channels)) self.dim_z = out_channels self.input_shape = None self.local_rep = local_rep if bias: self.bias = Parameter(torch.Tensor(out_channels)) self.use_bias = True self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal(self.weights, mode='fan_in') self.qz_loga.data.normal_( math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2) if self.use_bias: self.bias.data.fill_(0) def constrain_parameters(self, **kwargs): self.qz_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2)) def cdf_qz(self, x): """Implements the CDF of the 'stretched' concrete distribution""" xn = (x - limit_a) / (limit_b - limit_a) logits = math.log(xn) - math.log(1 - xn) return torch.sigmoid(logits * self.temperature - self.qz_loga).clamp( min=epsilon, max=1 - epsilon) def quantile_concrete(self, x): """Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution""" y = torch.sigmoid((torch.log(x) - torch.log(1 - x) + self.qz_loga) / self.temperature) return y * (limit_b - limit_a) + limit_a def _reg_w(self): """Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty""" q0 = self.cdf_qz(0) logpw_col = torch.sum( -(.5 * self.prior_prec * self.weights.pow(2)) - self.lamba, 3).sum(2).sum(1) logpw = torch.sum((1 - q0) * logpw_col) logpb = 0 if not self.use_bias else -torch.sum( (1 - q0) * (.5 * self.prior_prec * self.bias.pow(2) - self.lamba)) return logpw + logpb def regularization(self): return self._reg_w() def count_expected_flops_and_l0(self): """Measures the expected floating point operations (FLOPs) and the expected L0 norm""" ppos = torch.sum(1 - 
self.cdf_qz(0)) n = self.kernel_size[0] * self.kernel_size[ 1] * self.in_channels # vector_length flops_per_instance = n + (n - 1 ) # (n: multiplications and n-1: additions) num_instances_per_filter = ( (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 # for rows num_instances_per_filter *= ( (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 # multiplying with cols flops_per_filter = num_instances_per_filter * flops_per_instance expected_flops = flops_per_filter * ppos # multiply with number of filters expected_l0 = n * ppos if self.use_bias: # since the gate is applied to the output we also reduce the bias computation expected_flops += num_instances_per_filter * ppos expected_l0 += ppos return expected_flops.data[0], expected_l0.data[0] def get_eps(self, size): """Uniform random numbers for the concrete distribution""" eps = self.floatTensor(size).uniform_(epsilon, 1 - epsilon) eps = Variable(eps) return eps def sample_z(self, batch_size, sample=True): """Sample the hard-concrete gates for training and use a deterministic value for testing""" if sample: eps = self.get_eps(self.floatTensor(batch_size, self.dim_z)) z = self.quantile_concrete(eps).view(batch_size, self.dim_z, 1, 1) return F.hardtanh(z, min_val=0, max_val=1) else: # mode pi = torch.sigmoid(self.qz_loga).view(1, self.dim_z, 1, 1) return F.hardtanh(pi * (limit_b - limit_a) + limit_a, min_val=0, max_val=1) def sample_weights(self): z = self.quantile_concrete(self.get_eps(self.floatTensor( self.dim_z))).view(self.dim_z, 1, 1, 1) return F.hardtanh(z, min_val=0, max_val=1) * self.weights def forward(self, input_): if self.input_shape is None: self.input_shape = input_.size() b = None if not self.use_bias else self.bias if self.local_rep or not self.training: output = F.conv2d(input_, self.weights, b, self.stride, self.padding, self.dilation, self.groups) z = self.sample_z(output.size(0), sample=self.training) return output.mul(z) else: weights = self.sample_weights() output = F.conv2d(input_, weights, None, self.stride, self.padding, self.dilation, self.groups) return output def __repr__(self): s = ( '{name}({in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}, ' 'droprate_init={droprate_init}, temperature={temperature}, prior_prec={prior_prec}, ' 'lamba={lamba}, local_rep={local_rep}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if not self.use_bias: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
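The gates used by L0Conv2d are samples from the "stretched" hard-concrete distribution: a concrete (Gumbel-sigmoid) sample is stretched to the interval (limit_a, limit_b) and clipped to [0, 1], which puts point masses at exactly 0 and 1. A hedged standalone sketch with the usual constants limit_a = -0.1 and limit_b = 1.1:

import torch

def sample_hard_concrete(log_alpha, temperature=2.0 / 3.0, limit_a=-0.1, limit_b=1.1, eps=1e-6):
    u = torch.empty_like(log_alpha).uniform_(eps, 1.0 - eps)
    s = torch.sigmoid((torch.log(u) - torch.log(1.0 - u) + log_alpha) / temperature)
    s = s * (limit_b - limit_a) + limit_a      # stretch beyond [0, 1]
    return s.clamp(0.0, 1.0)                   # clip: exact zeros and ones occur with finite probability

gates = sample_hard_concrete(torch.zeros(1000))
print((gates == 0).float().mean().item(), (gates == 1).float().mean().item())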
class MAPDense(Module): def __init__(self, in_features, out_features, bias=True, weight_decay=1.0, **kwargs): super(MAPDense, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(in_features, out_features)) self.weight_decay = weight_decay if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter("bias", None) self.floatTensor = ( torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor ) self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal(self.weight, mode="fan_out") if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, **kwargs): pass def _reg_w(self, **kwargs): logpw = -torch.sum(self.weight_decay * 0.5 * (self.weight.pow(2))) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * 0.5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_expected_flops_and_l0(self): # dim_in multiplications and dim_in - 1 additions for each output neuron for the weights # + the bias addition for each neuron # total_flops = (2 * in_features - 1) * out_features + out_features expected_flops = (2 * self.in_features - 1) * self.out_features expected_l0 = self.in_features * self.out_features if self.bias is not None: expected_flops += self.out_features expected_l0 += self.out_features return expected_flops, expected_l0 def forward(self, input): output = input.mm(self.weight) if self.bias is not None: output.add_(self.bias.view(1, self.out_features).expand_as(output)) return output def __repr__(self): return ( self.__class__.__name__ + " (" + str(self.in_features) + " -> " + str(self.out_features) + ", weight_decay: " + str(self.weight_decay) + ")" )
class L0Dense(nn.Module): """ Implementation of L0 regularization for the input units of a fully connected layer """ def __init__(self, feature, embed_dim, weight_decay=0.0005, droprate=0.5, bias=False, temperature=2. / 3., lamda=1., local_rep=False, **kwargs): """ feature: input dimension embed_dim: output dimension bias: whether to use a bias weight_decay: strength of the L2 penalty droprate: dropout rate that the L0 gates will be initialized to temperature: temperature of the concrete distribution lamda: strength of the L0 penalty local_rep: whether to use a separate gate sample per element in the minibatch """ super(L0Dense, self).__init__() self.feature = feature self.embed_dim = embed_dim self.prior_prec = weight_decay self.temperature = temperature self.droprate = droprate self.lamda = lamda self.use_bias = bias self.local_rep = local_rep self.weights = Parameter(torch.Tensor(feature, embed_dim)) # one row of weights per input feature self.qz_loga = Parameter(torch.Tensor(feature)) # one gate per input feature if bias: self.bias = Parameter(torch.Tensor(embed_dim)) self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() def reset_parameters(self): init.kaiming_normal_(self.weights, mode='fan_out') self.qz_loga.data.normal_( math.log(1 - self.droprate) - math.log(self.droprate), 1e-2) if self.use_bias: self.bias.data.fill_(0) def constrain_parameters(self, **kwargs): self.qz_loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2)) def cdf_qz(self, x): # Implements CDF of the stretched concrete distribution xn = (x - limit_a) / (limit_b - limit_a) logits = math.log(xn) - math.log(1 - xn) return torch.sigmoid(logits * self.temperature - self.qz_loga).clamp( min=epsilon, max=1 - epsilon) def quantile_concrete(self, x): # Implements the quantile of the stretched concrete distribution y = torch.sigmoid((torch.log(x) - torch.log(1 - x) + self.qz_loga) / self.temperature) return y * (limit_b - limit_a) + limit_a def _reg_w(self): # Expected L0 norm under the stochastic gates logpw_col = torch.sum( -(.5 * self.prior_prec * self.weights.pow(2)) - self.lamda, 1) logpw = torch.sum((1 - self.cdf_qz(0)) * logpw_col) logpb = 0 if not self.use_bias else -torch.sum(.5 * self.prior_prec * self.bias.pow(2)) return logpw + logpb def regularization(self): return self._reg_w() def get_eps(self, size): # Uniform random numbers for the concrete distribution eps = self.floatTensor(size).uniform_(epsilon, 1 - epsilon) eps = Variable(eps) return eps def sample_z(self, batch_size, sample=True): # Sample the hard-concrete gates for training and use a deterministic value for testing # training if sample: eps = self.get_eps(self.floatTensor(batch_size, self.feature)) z = self.quantile_concrete(eps) return F.hardtanh(z, min_val=0, max_val=1) # testing else: pi = torch.sigmoid(self.qz_loga).view(1, self.feature).expand( batch_size, self.feature) return F.hardtanh(pi * (limit_b - limit_a) + limit_a, min_val=0, max_val=1) def sample_weights(self): z = self.quantile_concrete(self.get_eps(self.floatTensor( self.feature))) mask = F.hardtanh(z, min_val=0, max_val=1) return mask.view(self.feature, 1) * self.weights def forward(self, input): if self.local_rep or not self.training: z = self.sample_z(input.size(0), sample=self.training) xin = input.mul(z) output = xin.mm(self.weights) else: weights = self.sample_weights() output = input.mm(weights) if self.use_bias: output.add_(self.bias) return output
class HardConcreteGatedLinear(Module): """ Linear layer with stochastic connections, as in https://arxiv.org/abs/1712.01312 """ def __init__(self, in_features, out_features, l0_strength=1., l2_strength=1., bias=True, learn_weight=True, droprate_init=0.5, temperature=(2 / 3), **kwargs): """ :param in_features: Input dimensionality :param out_features: Output dimensionality :param bias: Whether we use a bias :param l2_strength: Strength of the L2 penalty :param droprate_init: Dropout rate that the L0 gates will be initialized to :param temperature: Temperature of the concrete distribution :param l0_strength: Strength of the L0 penalty """ super(HardConcreteGatedLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.l0_strength = l0_strength self.l2_strength = l2_strength self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) weight = torch.Tensor(out_features, in_features) if learn_weight: self.weight = Parameter(weight) else: self.register_buffer("weight", weight) self.loga = Parameter(torch.Tensor(out_features, in_features)) self.temperature = temperature self.droprate_init = droprate_init if droprate_init != 0. else 0.5 if bias: bias = torch.Tensor(out_features) if learn_weight: self.bias = Parameter(bias) else: self.register_buffer("bias", bias) self.use_bias = True else: self.use_bias = False self.reset_parameters() def reset_parameters(self): init.kaiming_normal_(self.weight, mode="fan_out") self.loga.data.normal_( math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2) if self.use_bias: self.bias.data.fill_(0) def constrain_parameters(self, **kwargs): self.loga.data.clamp_(min=math.log(1e-2), max=math.log(1e2)) def cdf_qz(self, x): """Implements the CDF of the 'stretched' concrete distribution""" xn = (x - LIMIT_A) / (LIMIT_B - LIMIT_A) logits = math.log(xn) - math.log(1 - xn) return torch.sigmoid(logits * self.temperature - self.loga).clamp( min=EPSILON, max=1 - EPSILON) def quantile_concrete(self, x): """ Implements the quantile, aka inverse CDF, of the 'stretched' concrete distribution """ y = torch.sigmoid( (torch.log(x) - torch.log(1 - x) + self.loga) / self.temperature) return y * (LIMIT_B - LIMIT_A) + LIMIT_A def weight_size(self): return self.weight.size() def regularization(self): """ Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty """ weight_decay_ungated = .5 * self.l2_strength * self.weight.pow(2) weight_l2_l0 = torch.sum( (weight_decay_ungated + self.l0_strength) * (1 - self.cdf_qz(0))) bias_l2 = (0 if not self.use_bias else torch.sum(.5 * self.l2_strength * self.bias.pow(2))) return weight_l2_l0 + bias_l2 def count_inference_flops(self): # For each unit, multiply with its n inputs then do n - 1 additions. # To capture the -1, subtract it, but only in cases where there is at # least one weight. 
nz_by_unit = self.get_inference_nonzeros() multiplies = torch.sum(nz_by_unit) adds = multiplies - torch.sum(nz_by_unit > 0) return multiplies.item(), adds.item() def count_expected_flops_and_l0(self): """ Measures the expected floating point operations (FLOPs) and the expected L0 norm Copied from the original L0 paper code """ # dim_in multiplications and dim_in - 1 additions for each output unit # for the weights # + the bias addition for each unit # total_flops = (2 * in_features - 1) * out_features + out_features ppos = torch.sum(1 - self.cdf_qz(0)) expected_flops = (2 * ppos - 1) * self.out_features expected_l0 = ppos * self.out_features if self.use_bias: expected_flops += self.out_features expected_l0 += self.out_features return expected_flops.data[0], expected_l0.data[0] def get_eps(self, size): """Uniform random numbers for the concrete distribution""" eps = self.floatTensor(size).uniform_(EPSILON, 1 - EPSILON) eps = Variable(eps) return eps def sample_weight(self): if self.training: z = self.quantile_concrete( self.get_eps(self.floatTensor(self.loga.size()))) mask = F.hardtanh(z, min_val=0, max_val=1) else: pi = torch.sigmoid(self.loga) mask = F.hardtanh(pi * (LIMIT_B - LIMIT_A) + LIMIT_A, min_val=0, max_val=1) return mask * self.weight def forward(self, x): return F.linear(x, self.sample_weight(), (self.bias if self.use_bias else None)) def get_expected_nonzeros(self): expected_gates = 1 - self.cdf_qz(0) return expected_gates.sum( dim=tuple(range(1, len(expected_gates.shape)))).detach() def get_inference_nonzeros(self): inference_gates = F.hardtanh(torch.sigmoid(self.loga) * (LIMIT_B - LIMIT_A) + LIMIT_A, min_val=0, max_val=1) return (inference_gates > 0).sum( dim=tuple(range(1, len(inference_gates.shape)))).detach() def __repr__(self): s = ("{name}({in_features} -> {out_features}, " "droprate_init={droprate_init}, l0_strength={l0_strength}, " "temperature={temperature}, l2_strength={l2_strength}, ") if not self.use_bias: s += ", bias=False" s += ")" return s.format(name=self.__class__.__name__, **self.__dict__)
class Conv2d(nn.Module): def __init__( self, in_channels, out_channels, kernel_width=(1, 1), stride=(1, 1), dilation=(1, 1), g_init=1.0, bias_init=0.1, causal=False, activation=None, ): super(Conv2d, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_width = kernel_width self.stride = stride self.dilation = dilation self.causal = causal self.activation = activation self.generating = False self.generating_reset = True self._weight = None self._input_cache = None self.padding = tuple(d * (w-1)//2 for w, d in zip(kernel_width, dilation)) self.bias = Parameter(torch.Tensor(out_channels)) self.weight_v = Parameter(torch.Tensor(out_channels, in_channels, *kernel_width)) self.weight_g = Parameter(torch.Tensor(out_channels)) if causal: if any(w % 2 == 0 for w in kernel_width): raise HyperparameterError(f"Even kernel width incompatible with causal convolution: {kernel_width}") if kernel_width == (1, 3): # make common case explicit mask = torch.Tensor([1., 1., 0.]) elif kernel_width[0] == 1: mask = torch.ones(kernel_width) mask[0, kernel_width[1] // 2 + 1:] = 0 else: mask = torch.ones(kernel_width) mask[kernel_width[0] // 2, kernel_width[1] // 2:] = 0 mask[kernel_width[0] // 2 + 1:, :] = 0 mask = mask.view(1, 1, *kernel_width) self.register_buffer('mask', mask) else: self.register_buffer('mask', None) self.reset_parameters(g_init=g_init, bias_init=bias_init) def reset_parameters(self, v_mean=0., v_std=0.05, g_init=1.0, bias_init=0.1): nn.init.normal_(self.weight_v, mean=v_mean, std=v_std) nn.init.constant_(self.weight_g, val=g_init) nn.init.constant_(self.bias, val=bias_init) def generate(self, mode=True): self.generating = mode self.generating_reset = True self._weight = None self._input_cache = None return self def weight_costs(self): return ( self.weight_v.pow(2).sum(), self.weight_g.pow(2).sum(), self.bias.pow(2).sum() ) @property def weight(self): shape = (self.out_channels, 1, 1, 1) weight = l2_norm_except_dim(self.weight_v, 0) * self.weight_g.view(shape) if self.mask is not None: weight = weight * self.mask return weight def forward(self, inputs): """ :param inputs: (N, C_in, H, W) :return: (N, C_out, H, W) """ if self.generating: if self.generating_reset: self.generating_reset = False if self.kernel_width != (1, 1): self._input_cache = inputs else: return self.forward_generate(inputs) h = F.conv2d(inputs, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation) if self.activation is not None: h = self.activation(h) return h def forward_generate(self, inputs): """Calculates forward for the last position in `inputs` Only implemented for kernel widths (1, 1) and (1, 3) and stride (1, 1). If the kernel width is (1, 3), causal must be True. 
:param inputs: tensor(N, C_in, 1, 1) :return: tensor(N, C_out, 1, 1) """ if self._weight is None: self._weight = self.weight self._weight = self._weight.transpose(0, 1) if self.kernel_width == (1, 1): h = inputs[:, :, 0, -1] @ self._weight[:, :, 0, 0] + self.bias.view(1, self.out_channels) elif self.kernel_width == (1, 3): h = inputs[:, :, 0, -1] @ self._weight[:, :, 0, 1] if self.dilation[1] < self._input_cache.size(3): h += self._input_cache[:, :, 0, -self.dilation[1]] @ self._weight[:, :, 0, 0] h += self.bias.view(1, self.out_channels) self._input_cache = torch.cat([self._input_cache, inputs], dim=3) else: raise HyperparameterError(f"Generate not supported for kernel width {self.kernel_width}.") if self.activation is not None: h = self.activation(h) return h.unsqueeze(-1).unsqueeze(-1) def extra_repr(self): s = '{in_channels}, {out_channels}, kernel_size={kernel_width}' if self.stride != (1,) * len(self.stride): s += ', stride={stride}' if self.dilation != (1,) * len(self.dilation): s += ', dilation={dilation}' if self.causal: s += ', causal=True' return s.format(**self.__dict__)
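A hedged sketch of the generate-mode protocol suggested by the reset logic above: after calling generate(True), the first forward call receives the whole prefix (priming the input cache), and each subsequent call passes a single new position through forward_generate. With a causal (1, 3) kernel the incremental output should match the last column of the full convolution; sizes and tolerance are illustrative.

import torch

conv = Conv2d(4, 8, kernel_width=(1, 3), causal=True)
x = torch.randn(2, 4, 1, 10)                     # (N, C_in, H=1, W=10)

full = conv(x)                                    # ordinary causal convolution over all positions

conv.generate(True)
prefix = conv(x[:, :, :, :9])                     # first call primes the cache with the prefix
last = conv(x[:, :, :, 9:])                       # later calls compute one new position each
print(torch.allclose(last, full[:, :, :, -1:], atol=1e-5))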
class BinaryGatedLinear(Module): """ Linear layer with stochastic binary gates """ def __init__(self, in_features, out_features, l0_strength=1., l2_strength=1., learn_weight=True, bias=True, droprate_init=0.5, **kwargs): """ :param in_features: Input dimensionality :param out_features: Output dimensionality :param bias: Whether we use a bias :param l2_strength: Strength of the L2 penalty :param droprate_init: Dropout rate that the gates will be initialized to :param l0_strength: Strength of the L0 penalty """ super(BinaryGatedLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.l0_strength = l0_strength self.l2_strength = l2_strength self.floatTensor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) weight = torch.Tensor(out_features, in_features) if learn_weight: self.weight = Parameter(weight) else: self.register_buffer("weight", weight) self.logit_p1 = Parameter(torch.Tensor(out_features, in_features)) self.droprate_init = droprate_init if droprate_init != 0. else 0.5 self.use_bias = False if bias: b = torch.Tensor(out_features) if learn_weight: self.bias = Parameter(b) else: self.register_buffer("bias", b) self.use_bias = True self.reset_parameters() def reset_parameters(self): init.kaiming_normal_(self.weight, mode="fan_out") self.logit_p1.data.normal_( math.log(1 - self.droprate_init) - math.log(self.droprate_init), 1e-2) if self.use_bias: self.bias.data.fill_(0) def constrain_parameters(self, **kwargs): pass def regularization(self): """ Expected L0 norm under the stochastic gates, takes into account and re-weights also a potential L2 penalty """ p1 = torch.sigmoid(self.logit_p1) weight_decay_ungated = .5 * self.l2_strength * self.weight.pow(2) weight_l2_l0 = torch.sum( (weight_decay_ungated + self.l0_strength) * p1) bias_l2 = (0 if not self.use_bias else torch.sum(.5 * self.l2_strength * self.bias.pow(2))) return -weight_l2_l0 - bias_l2 def sample_weight(self): u = self.floatTensor(self.logit_p1.size()).uniform_(0, 1) p1 = torch.sigmoid(self.logit_p1) mask = p1 > u def cc_to_p1(grad): ratio = p1 / (1 - p1) p1.backward(grad * torch.where(mask, 1 / ratio, ratio)) return grad z = mask.float() z.requires_grad_() z.register_hook(cc_to_p1) return self.weight * z def forward(self, x): return F.linear(x, self.sample_weight(), (self.bias if self.use_bias else None)) def get_expected_nonzeros(self): expected_gates = torch.sigmoid(self.logit_p1) return expected_gates.sum( dim=tuple(range(1, len(expected_gates.shape)))).detach() def get_inference_nonzeros(self): u = self.floatTensor(self.logit_p1.size()).uniform_(0, 1) inference_gates = torch.sigmoid(self.logit_p1) > u return inference_gates.sum( dim=tuple(range(1, len(inference_gates.shape)))).detach()
class group_relaxed_L1L2Conv2d(Module): """Implementation of group relaxed L1-L2 regularization for the feature maps of a convolutional layer""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, lamba=1., alpha=1., beta=4., weight_decay=1., **kwargs): """ :param in_channels: Number of input channels :param out_channels: Number of output channels :param kernel_size: size of the kernel :param stride: stride for the convolution :param padding: padding for the convolution :param dilation: dilation factor for the convolution :param groups: how many groups we will assume in the convolution :param bias: whether we will use a bias :param lamba: strength of the L1-L2 regularization """ super(group_relaxed_L1L2Conv2d, self).__init__() self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.lamba = lamba self.alpha = alpha self.beta = beta self.lamba1 = self.lamba / self.beta self.weight_decay = weight_decay self.weight = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size) if torch.cuda.is_available(): self.u = self.u.to('cuda') if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() self.input_shape = None print(self) def reset_parameters(self): init.kaiming_normal_(self.weight, mode='fan_in') if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, **kwargs): norm_w = self.weight.data.norm(p=float('inf')) if norm_w > self.lamba1: m = Softshrink(self.lamba1) z = m(self.weight.data) self.u.data = z * (z.data.norm(p=2) + self.alpha * self.lamba1) / (z.data.norm(p=2)) elif norm_w == self.lamba1: self.u = self.weight.clone() self.u[self.u.abs() < self.lamba1] = 0 n = torch.sum(self.u != 0) self.u[self.u != 0] = self.weight[self.u != 0].sign( ) * self.alpha * self.lamba1 / (n**(1 / 2)) elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1: self.u = self.weight.clone() max_idx = np.unravel_index(torch.argmax(self.u.cpu(), None), self.u.shape) max_value_sign = self.u[max_idx].sign() self.u[:] = 0 self.u[max_idx] = (norm_w + (self.alpha - 1) * self.lamba1) * max_value_sign else: self.u = self.weight.clone() self.u[:] = 0 def grow_beta(self, growth_factor): self.beta = self.beta * growth_factor self.lamba1 = self.lamba / self.beta def _reg_w(self, **kwargs): logpw = -self.beta * torch.sum( 0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt( self.in_channels * self.kernel_size[0] * self.kernel_size[1] ) * torch.sum( torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5)) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_zero_u(self): total = np.prod(self.u.size()) zero = total - self.u.nonzero().size(0) return zero def count_zero_w(self): return torch.sum((self.weight.abs() < 1e-5).int()).item() def count_active_neuron(self): return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) / (self.in_channels * self.kernel_size[0] * self.kernel_size[1])) > 1e-5).item() def count_total_neuron(self): return self.out_channels def count_weight(self): return np.prod(self.u.size()) def count_expected_flops_and_l0(self): #ppos = self.out_channels ppos = torch.sum( torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item() n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels flops_per_instance = n + (n - 1) num_instances_per_filter = ( (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 num_instances_per_filter *= ( (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 flops_per_filter = num_instances_per_filter * flops_per_instance expected_flops = flops_per_filter * ppos expected_l0 = n * ppos if self.bias is not None: expected_flops += num_instances_per_filter * ppos expected_l0 += ppos return expected_flops, expected_l0 def forward(self, input_): if self.input_shape is None: self.input_shape = input_.size() output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return output def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} ' ', stride={stride}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if self.bias is None: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
class LinearGroupNJ(Module): """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout). References: [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). """ def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None): super(LinearGroupNJ, self).__init__() self.cuda = cuda self.in_features = in_features self.out_features = out_features self.clip_var = clip_var self.deterministic = False # flag is used for compressed inference # trainable params according to Eq.(6) # dropout params self.z_mu = Parameter(torch.Tensor(in_features)) self.z_logvar = Parameter(torch.Tensor(in_features)) # = z_mu^2 * alpha # weight params self.weight_mu = Parameter(torch.Tensor(out_features, in_features)) self.weight_logvar = Parameter(torch.Tensor(out_features, in_features)) self.bias_mu = Parameter(torch.Tensor(out_features)) self.bias_logvar = Parameter(torch.Tensor(out_features)) # init params either random or with pretrained net self.reset_parameters(init_weight, init_bias) # activations for kl self.sigmoid = nn.Sigmoid() self.softplus = nn.Softplus() # numerical stability param self.epsilon = 1e-8 def reset_parameters(self, init_weight, init_bias): # init means stdv = 1. / math.sqrt(self.weight_mu.size(1)) self.z_mu.data.normal_(1, 1e-2) if init_weight is not None: self.weight_mu.data = torch.Tensor(init_weight) else: self.weight_mu.data.normal_(0, stdv) if init_bias is not None: self.bias_mu.data = torch.Tensor(init_bias) else: self.bias_mu.data.fill_(0) # init logvars self.z_logvar.data.normal_(-9, 1e-2) self.weight_logvar.data.normal_(-9, 1e-2) self.bias_logvar.data.normal_(-9, 1e-2) def clip_variances(self): if self.clip_var: self.weight_logvar.data.clamp_(max=math.log(self.clip_var)) self.bias_logvar.data.clamp_(max=math.log(self.clip_var)) def get_log_dropout_rates(self): log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) return log_alpha def compute_posterior_params(self): weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var self.post_weight_mu = self.weight_mu * self.z_mu return self.post_weight_mu, self.post_weight_var def forward(self, x): if self.deterministic: assert self.training == False, "Flag deterministic is True. This should not be used in training." return F.linear(x, self.post_weight_mu, self.bias_mu) batch_size = x.size()[0] # compute z # note that we reparametrise according to [2] Eq. (11) (not [1]) z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training) # apply local reparametrisation trick see [1] Eq. (6) # to the parametrisation given in [3] Eq. 
(6) xz = x * z mu_activations = F.linear(xz, self.weight_mu, self.bias_mu) var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp()) return reparametrize(mu_activations, var_activations.log(), sampling=self.training) def kl_divergence(self): # KL(q(z)||p(z)) # we use the kl divergence approximation given by [2] Eq.(14) k1, k2, k3 = 0.63576, 1.87320, 1.48695 log_alpha = self.get_log_dropout_rates() KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) # KL(q(w|z)||p(w|z)) # we use the kl divergence given by [3] Eq.(8) KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) # KL bias KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) return KLD def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')'
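A hedged training sketch for the Normal-Jeffreys layers (assuming the reparametrize helper from the same codebase): the task loss plus the summed KL divergences scaled by the training-set size, with pruning decided afterwards from the log dropout rates. Architecture, dataset size and the threshold of 3 are illustrative.

import torch
import torch.nn.functional as F

fc1 = LinearGroupNJ(784, 300)
fc2 = LinearGroupNJ(300, 10)

x = torch.randn(16, 784)
y = torch.randint(0, 10, (16,))

logits = fc2(F.relu(fc1(x)))
N = 60000                                                    # assumed number of training examples
loss = F.cross_entropy(logits, y) + (fc1.kl_divergence() + fc2.kl_divergence()) / N
loss.backward()

# After training, input units with a large log dropout rate can be pruned:
keep = fc1.get_log_dropout_rates() < 3.0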
class LinearGroupNJ(BayesianLayers): """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout). References: [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). """ def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None): super(LinearGroupNJ, self).__init__() self.cuda = cuda self.in_features = in_features self.out_features = out_features self.clip_var = clip_var self.deterministic = False # flag is used for compressed inference # trainable params according to Eq.(6) # dropout params self.z_mu = Parameter(torch.Tensor(in_features)) self.z_logvar = Parameter(torch.Tensor(in_features)) # = z_mu^2 * alpha # weight params self.weight_mu = Parameter(torch.Tensor(out_features, in_features)) self.weight_logvar = Parameter(torch.Tensor(out_features, in_features)) self.bias_mu = Parameter(torch.Tensor(out_features)) self.bias_logvar = Parameter(torch.Tensor(out_features)) # init params either random or with pretrained net self.reset_parameters(init_weight, init_bias) # activations for kl self.sigmoid = nn.Sigmoid() self.softplus = nn.Softplus() # numerical stability param self.epsilon = 1e-8 def reset_parameters(self, init_weight, init_bias): # init means stdv = 1. / math.sqrt(self.weight_mu.size(1)) self.z_mu.data.normal_(1, 1e-2) if init_weight is not None: self.weight_mu.data = torch.Tensor(init_weight) else: self.weight_mu.data.normal_(0, stdv) if init_bias is not None: self.bias_mu.data = torch.Tensor(init_bias) else: self.bias_mu.data.fill_(0) # init logvars self.z_logvar.data.normal_(-9, 1e-2) self.weight_logvar.data.normal_(-9, 1e-2) self.bias_logvar.data.normal_(-9, 1e-2) def clip_variances(self): if self.clip_var: self.weight_logvar.data.clamp_(max=math.log(self.clip_var)) self.bias_logvar.data.clamp_(max=math.log(self.clip_var)) def get_log_dropout_rates(self): log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) return log_alpha def compute_posterior_params(self): weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var self.post_weight_mu = self.weight_mu * self.z_mu # print("self.z_mu.pow(2): ", self.z_mu.pow(2).size()) # print("weight_var: ", weight_var.size()) # print("z_var: ", z_var.size()) # print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size()) # print("weight_var: ", weight_var.size()) # print("post_weight_mu: ", self.post_weight_mu.size()) # print("post_weight_var: ", self.post_weight_var.size()) return self.post_weight_mu, self.post_weight_var def forward(self, x): if self.deterministic: assert self.training == False, "Flag deterministic is True. This should not be used in training." return F.linear(x, self.post_weight_mu, self.bias_mu) batch_size = x.size()[0] # compute z # note that we reparametrise according to [2] Eq. (11) (not [1]) z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training, cuda=self.cuda) # apply local reparametrisation trick see [1] Eq. (6) # to the parametrisation given in [3] Eq. 
(6) xz = x * z mu_activations = F.linear(xz, self.weight_mu, self.bias_mu) var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp()) return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda) def kl_divergence(self): # KL(q(z)||p(z)) # we use the kl divergence approximation given by [2] Eq.(14) k1, k2, k3 = 0.63576, 1.87320, 1.48695 log_alpha = self.get_log_dropout_rates() KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) # KL(q(w|z)||p(w|z)) # we use the kl divergence given by [3] Eq.(8) KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) # KL bias KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) return KLD def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')'
class group_relaxed_L1L2Dense(Module): """Implementation of group relaxed L1-L2 regularization for the input units of a fully connected layer""" def __init__(self, in_features, out_features, bias=True, lamba=1., alpha=1., beta=4., weight_decay=1., **kwargs): """ :param in_features: input dimensionality :param out_features: output dimensionality :param bias: whether we use bias :param lamba: strength of the L1-L2 regularization """ super(group_relaxed_L1L2Dense, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(in_features, out_features)) self.u = torch.rand(in_features, out_features) if torch.cuda.is_available(): self.u = self.u.to('cuda') if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.lamba = lamba self.alpha = alpha self.beta = beta self.lamba1 = self.lamba / self.beta self.weight_decay = weight_decay self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.weight, mode='fan_out') if self.bias is not None: self.bias.data.normal_(0, 1e-2) def constrain_parameters(self, **kwargs): norm_w = self.weight.data.norm(p=float('inf')) if norm_w > self.lamba1: m = Softshrink(self.lamba1) z = m(self.weight.data) self.u.data = z * (z.data.norm(p=2) + self.alpha * self.lamba1) / (z.data.norm(p=2)) elif norm_w == self.lamba1: self.u = self.weight.clone() self.u[self.u.abs() < self.lamba1] = 0 n = torch.sum(self.u != 0) self.u[self.u != 0] = self.weight[self.u != 0].sign( ) * self.alpha * self.lamba1 / (n**(1 / 2)) elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1: self.u = self.weight.clone() max_idx = np.unravel_index(torch.argmax(self.u.cpu(), None), self.u.shape) max_value_sign = self.u[max_idx].sign() self.u[:] = 0 self.u[max_idx] = (norm_w + (self.alpha - 1) * self.lamba1) * max_value_sign else: self.u = self.weight.clone() self.u[:] = 0 def grow_beta(self, growth_factor): self.beta = self.beta * growth_factor self.lamba1 = self.lamba / self.beta def _reg_w(self, **kwargs): logpw = -self.beta * torch.sum( 0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt( self.out_features) * torch.sum( torch.pow(torch.sum(self.weight.pow(2), 1), 0.5)) logpb = 0 if self.bias is not None: logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2))) return logpw + logpb def regularization(self): return self._reg_w() def count_zero_u(self): total = np.prod(self.u.size()) zero = total - self.u.nonzero().size(0) return zero def count_zero_w(self): return torch.sum((self.weight.abs() < 1e-5).int()).item() def count_weight(self): return np.prod(self.u.size()) def count_active_neuron(self): return torch.sum( torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item() def count_total_neuron(self): return self.in_features def count_expected_flops_and_l0(self): ppos = torch.sum(self.weight.abs() > 0.000001).item() expected_flops = (2 * ppos - 1) * self.out_features expected_l0 = ppos * self.out_features if self.bias is not None: expected_flops += self.out_features expected_l0 += self.out_features return expected_flops, expected_l0 def forward(self, input): output = input.mm(self.weight) if self.bias is not None: output.add_(self.bias.view(1, self.out_features).expand_as(output)) return output def __repr__(self): return self.__class__.__name__+' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', lambda: ' \ + str(self.lamba) + ')'
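A hedged sketch of the training pattern these relaxed-splitting layers suggest: minimize the task loss minus the log-prior, then apply the proximal update of the auxiliary variable u via constrain_parameters and grow beta so that the weights are pulled ever closer to u. Sizes, learning rate and the growth factor are illustrative assumptions.

import torch

layer = group_relaxed_L1L2Dense(100, 50, lamba=0.1)
opt = torch.optim.SGD(layer.parameters(), lr=0.01)

x = torch.randn(32, 100)
y = torch.randn(32, 50)

for epoch in range(3):
    loss = (layer(x) - y).pow(2).mean() - layer.regularization()   # regularization() is a log-prior
    opt.zero_grad()
    loss.backward()
    opt.step()
    layer.constrain_parameters()     # proximal update of the auxiliary variable u
    layer.grow_beta(1.1)             # tighten the coupling between the weights and u over time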
class HSDense(Module):
    def __init__(self, in_features, out_features, bias=True, prior_std=1., prior_std_z=1., dof=1., **kwargs):
        super(HSDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std
        self.mean_w = Parameter(torch.Tensor(in_features, out_features))
        self.logvar_w = Parameter(torch.Tensor(in_features, out_features))
        self.qz_mean = Parameter(torch.Tensor(in_features))
        self.qz_logvar = Parameter(torch.Tensor(in_features))
        self.dof = dof
        self.prior_std_z = prior_std_z
        self.use_bias = False
        if bias:
            self.mean_bias = Parameter(torch.Tensor(out_features))
            self.logvar_bias = Parameter(torch.Tensor(out_features))
            self.use_bias = True
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.mean_w, mode='fan_out')
        self.logvar_w.data.normal_(-9., 1e-4)
        self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3)
        self.qz_logvar.data.normal_(math.log(0.1), 1e-4)
        if self.use_bias:
            self.mean_bias.data.normal_(0, 1e-2)
            self.logvar_bias.data.normal_(-9., 1e-4)

    def constrain_parameters(self, thres_std=1.):
        self.logvar_w.data.clamp_(max=2. * math.log(thres_std))
        if self.use_bias:
            self.logvar_bias.data.clamp_(max=2. * math.log(thres_std))

    def eq_logpw(self):
        logpw = -.5 * math.log(2 * math.pi * self.prior_std ** 2) - .5 * self.logvar_w.exp().div(self.prior_std ** 2)
        logpw -= .5 * self.mean_w.pow(2).div(self.prior_std ** 2)
        logpb = 0.
        if self.use_bias:
            logpb = -.5 * math.log(2 * math.pi * self.prior_std ** 2) - .5 * self.logvar_bias.exp().div(self.prior_std ** 2)
            logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std ** 2)
        return torch.sum(logpw) + torch.sum(logpb)

    def eq_logqw(self):
        logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1))
        logqb = 0.
        if self.use_bias:
            logqb = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_bias + 1))
        return logqw + logqb

    def kldiv_aux(self):
        z = self.sample_z(1)
        z = z.view(self.in_features)

        logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1))
        logqm = logqm.add(-torch.sum(F.sigmoid(z.exp().add(-1).log()).log()))

        logpm = torch.sum(2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof)
                          - math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi)
                          - .5 * (self.dof + 1) * torch.log(1. + z.pow(2) / (self.dof * self.prior_std_z ** 2)))
        return logpm - logqm

    def kldiv(self):
        return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw()

    def get_eps(self, size):
        eps = self.floatTensor(size).normal_()
        eps = Variable(eps)
        return eps

    def sample_z(self, batch_size):
        z = self.qz_mean.view(1, self.in_features)
        if self.training:
            eps = self.get_eps(self.floatTensor(batch_size, self.in_features))
            z = z + eps.mul(self.qz_logvar.view(1, self.in_features).mul(0.5).exp_())
        return F.softplus(z)

    def sample_W(self):
        W = self.mean_w
        if self.training:
            eps = self.get_eps(self.mean_w.size())
            W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_()))
        return W

    def sample_b(self):
        b = self.mean_bias
        if self.training:
            eps = self.get_eps(self.mean_bias.size())
            b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_()))
        return b

    def get_mean_x(self, input):
        mean_xin = input.mm(self.mean_w)
        if self.use_bias:
            mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features))
        return mean_xin

    def get_var_x(self, input):
        var_xin = input.pow(2).mm(self.logvar_w.exp())
        if self.use_bias:
            var_xin = var_xin.add(self.logvar_bias.exp().view(1, self.out_features))
        return var_xin

    def forward(self, input):
        # sampling
        batch_size = input.size(0)
        z = self.sample_z(batch_size)
        xin = input.mul(z)
        mean_xin = self.get_mean_x(xin)
        output = mean_xin
        if self.training:
            var_xin = self.get_var_x(xin)
            eps = self.get_eps(self.floatTensor(batch_size, self.out_features))
            output = output.add(var_xin.sqrt().mul(eps))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', dof: ' \
            + str(self.dof) + ', prior_std_z: ' \
            + str(self.prior_std_z) + ')'
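# Hedged sketch (not part of the original file): one way to inspect which input units a
# horseshoe-style layer like HSDense has effectively shrunk away is to look at the
# deterministic scale z in eval mode (sample_z returns softplus(qz_mean) when
# self.training is False) and threshold it. `layer` and `threshold` are assumed names,
# and the cutoff value is purely illustrative.
def active_input_units(layer, threshold=1e-2):
    layer.eval()  # note: this switches the layer out of training mode
    with torch.no_grad():
        z = layer.sample_z(1).view(-1)  # deterministic scales, one per input feature
    return (z > threshold).nonzero().view(-1)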