class VariationalDropout(nn.Module): def __init__(self, input_size, out_size, log_sigma2=-10, threshold=3): """ :param input_size: An int of input size :param log_sigma2: Initial value of log sigma ^ 2. It is crucial for training since it determines the initial value of alpha :param threshold: Threshold used at validation. If log_alpha > threshold, the weight is zeroed :param out_size: An int of output size """ super(VariationalDropout, self).__init__() self.input_size = input_size self.out_size = out_size self.theta = Parameter(t.FloatTensor(input_size, out_size)) self.bias = Parameter(t.Tensor(out_size)) self.log_sigma2 = Parameter(t.FloatTensor(input_size, out_size).fill_(log_sigma2)) self.reset_parameters() self.threshold = threshold def forward(self, input): # Local Reparameterization Trick log_alpha = self.clip(self.log_sigma2 - t.log(self.theta ** 2)) kld = self.kld(log_alpha) if not self.training: mask = log_alpha > self.threshold return t.addmm(self.bias, input, self.theta.masked_fill(mask, 0)) mu = t.mm(input, self.theta) std = t.sqrt(t.mm(input ** 2, self.log_sigma2.exp()) + 1e-6) eps = Variable(t.randn(*mu.size())) if input.is_cuda: eps = eps.cuda() return std * eps + mu + self.bias, kld def reset_parameters(self): stdv = 1. / math.sqrt(self.out_size) self.theta.data.uniform_(-stdv, stdv) self.bias.data.uniform_(-stdv, stdv) def clip(self, input, to=8): input = input.masked_fill(input < -to, -to) input = input.masked_fill(input > to, to) return input def kld(self, log_alpha): # from the paper "Variational Dropout Sparsifies Deep Neural Networks" k = [0.63576, 1.87320, 1.48695] first_term = k[0] * t.sigmoid(k[1] + k[2] * log_alpha) second_term = 0.5 * t.log(1 + t.exp(-log_alpha)) return - (first_term - second_term - k[0]).sum() / (self.input_size * self.out_size)
class VDropLinear(Module): def __init__(self, in_features, out_features, bias=True): super().__init__() self.w_mu = Parameter(torch.Tensor(out_features, in_features)) init.kaiming_normal_(self.w_mu, mode="fan_out") self.w_logsigma2 = Parameter(torch.Tensor(out_features, in_features)) self.w_logsigma2.data.fill_(-10) if bias: self.bias = Parameter(torch.Tensor(out_features)) self.bias.data.fill_(0) else: self.bias = None self.threshold = 3 self.epsilon = 1e-8 self.tensor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) def compute_mask(self): w_logalpha = self.w_logsigma2 - (self.w_mu ** 2 + self.epsilon).log() return (w_logalpha < self.threshold).float() def forward(self, x): if self.training: y_mu = F.linear(x, self.w_mu, self.bias) # Avoid sqrt(0), otherwise a divide-by-zero occurs during backprop. y_sigma = F.linear( x ** 2, self.w_logsigma2.exp() ).clamp(self.epsilon).sqrt() rv = self.tensor(y_mu.size()).normal_() return y_mu + (rv * y_sigma) else: return F.linear(x, self.w_mu * self.compute_mask(), self.bias) def regularization(self): k1, k2, k3 = 0.63576, 1.8732, 1.48695 w_logalpha = self.w_logsigma2 - (self.w_mu ** 2 + self.epsilon).log() return -(k1 * torch.sigmoid(k2 + k3 * w_logalpha) - 0.5 * F.softplus(-w_logalpha) - k1).sum() def get_inference_nonzeros(self): return self.compute_mask().int().sum(dim=1) def count_inference_flops(self): # For each unit, multiply with its n inputs then do n - 1 additions. # To capture the -1, subtract it, but only in cases where there is at # least one weight. nz_by_unit = self.get_inference_nonzeros() multiplies = torch.sum(nz_by_unit) adds = multiplies - torch.sum(nz_by_unit > 0) return multiplies.item(), adds.item() def weight_size(self): return self.w_mu.size()
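A minimal usage sketch (not part of the original code) showing how a layer like the VDropLinear above is typically trained: its KL-based regularization() term is summed over modules and added to the task loss. The layer sizes, data, and the reg_weight coefficient are illustrative assumptions; in practice the coefficient is often annealed over training.

import torch
import torch.nn.functional as F

# Hypothetical model built from the VDropLinear module defined above.
model = torch.nn.Sequential(VDropLinear(784, 300), torch.nn.ReLU(),
                            VDropLinear(300, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(x, y, reg_weight=1e-5):
    """One optimization step: task loss plus the weighted variational-dropout KL."""
    model.train()
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    # Sum the KL regularizer over every module that exposes one
    # (reg_weight is an assumed fixed coefficient; it is usually annealed).
    loss = loss + reg_weight * sum(m.regularization() for m in model.modules()
                                   if hasattr(m, "regularization"))
    loss.backward()
    optimizer.step()
    return loss.item()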
class OptNetEq(nn.Module): def __init__(self, n, Qpenalty, qp_solver, trueInit=False): super().__init__() self.qp_solver = qp_solver nx = (n**2)**3 self.Q = Variable(Qpenalty * torch.eye(nx).double().cuda()) self.Q_idx = spa.csc_matrix(self.Q.detach().cpu().numpy()).nonzero() self.G = Variable(-torch.eye(nx).double().cuda()) self.h = Variable(torch.zeros(nx).double().cuda()) t = get_sudoku_matrix(n) if trueInit: self.A = Parameter(torch.DoubleTensor(t).cuda()) else: self.A = Parameter(torch.rand(t.shape).double().cuda()) self.log_z0 = Parameter(torch.zeros(nx).double().cuda()) # self.b = Variable(torch.ones(self.A.size(0)).double().cuda()) if self.qp_solver == 'osqpth': t = torch.cat((self.A, self.G), dim=0) self.AG_idx = spa.csc_matrix(t.detach().cpu().numpy()).nonzero() # @profile def forward(self, puzzles): nBatch = puzzles.size(0) p = -puzzles.view(nBatch, -1) b = self.A.mv(self.log_z0.exp()) if self.qp_solver == 'qpth': y = QPFunction(verbose=-1)(self.Q, p.double(), self.G, self.h, self.A, b).float().view_as(puzzles) elif self.qp_solver == 'osqpth': _l = torch.cat((b, torch.full(self.h.shape, float('-inf'), device=self.h.device, dtype=self.h.dtype)), dim=0) _u = torch.cat((b, self.h), dim=0) Q_data = self.Q[self.Q_idx[0], self.Q_idx[1]] AG = torch.cat((self.A, self.G), dim=0) AG_data = AG[self.AG_idx[0], self.AG_idx[1]] y = OSQP(self.Q_idx, self.Q.shape, self.AG_idx, AG.shape, diff_mode=DiffModes.FULL)(Q_data, p.double(), AG_data, _l, _u).float().view_as(puzzles) else: assert False return y
class VDropLinear(nn.Module): def __init__(self, in_features, out_features, bias=True): super().__init__() self.weight = Parameter(torch.Tensor(out_features, in_features)) init.kaiming_normal_(self.weight, mode="fan_out") self.w_logvar = Parameter(torch.Tensor(out_features, in_features)) self.w_logvar.data.fill_(-10) if bias: self.bias = Parameter(torch.Tensor(out_features)) self.bias.data.fill_(0) else: self.bias = None self.threshold = 3 self.epsilon = 1e-8 self.tensor_constructor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) def constrain_parameters(self): self.w_logvar.data.clamp_(min=-10., max=10.) def compute_mask(self): w_logalpha = self.w_logvar - (self.weight ** 2 + self.epsilon).log() return (w_logalpha < self.threshold).float() def forward(self, x): if self.training: return vdrop_linear_forward(x, lambda: self.weight, lambda: self.w_logvar.exp(), self.bias, self.tensor_constructor, self.epsilon) else: return F.linear(x, self.weight * self.compute_mask(), self.bias) def regularization(self): w_logalpha = self.w_logvar - (self.weight ** 2 + self.epsilon).log() return vdrop_regularization(w_logalpha).sum() def get_inference_nonzeros(self): return self.compute_mask().int().sum(dim=1) def count_inference_flops(self): # For each unit, multiply with its n inputs then do n - 1 additions. # To capture the -1, subtract it, but only in cases where there is at # least one weight. nz_by_unit = self.get_inference_nonzeros() multiplies = torch.sum(nz_by_unit) adds = multiplies - torch.sum(nz_by_unit > 0) return multiplies.item(), adds.item() def weight_size(self): return self.weight.size()
class VariationalDropout(nn.Module): def __init__(self, input_size, out_size, log_sigma2=-10, threshold=3): """ :param input_size: An int of input size :param log_sigma2: Initial value of log_sigma^2 (crucial for training as it determines initial value of alpha) :param threshold: Value for thresholding of validation. If log_alpha > threshold, then weight is zeroed :param out_size: An int of output size """ super(VariationalDropout, self).__init__() self.input_size = input_size self.out_size = out_size self.theta = Parameter(torch.FloatTensor(input_size, out_size)) self.bias = Parameter(torch.Tensor(out_size)) self.log_sigma2 = Parameter( torch.FloatTensor(input_size, out_size).fill_(log_sigma2)) self.reset_parameters() self.k = [0.63576, 1.87320, 1.48695] self.threshold = threshold def reset_parameters(self): stdv = 1. / math.sqrt(self.out_size) self.theta.data.uniform_(-stdv, stdv) self.bias.data.uniform_(-stdv, stdv) @staticmethod def clip(input, to=8): input = input.masked_fill(input < -to, -to) input = input.masked_fill(input > to, to) return input def kld(self, log_alpha): first_term = self.k[0] * F.sigmoid(self.k[1] + self.k[2] * log_alpha) second_term = 0.5 * torch.log(1 + torch.exp(-log_alpha)) return -(first_term - second_term - self.k[0]).sum() / ( self.input_size * self.out_size) def forward(self, input): """ :param input: An float tensor with shape of [batch_size, input_size] :return: An float tensor with shape of [batch_size, out_size] and negative layer-kld estimation """ log_alpha = self.clip(self.log_sigma2 - torch.log(self.theta**2)) kld = self.kld(log_alpha) if not self.training: mask = log_alpha > self.threshold return torch.addmm(self.bias, input, self.theta.masked_fill(mask, 0)) mu = torch.mm(input, self.theta) std = torch.sqrt(torch.mm(input**2, self.log_sigma2.exp()) + 1e-6) eps = Variable(torch.randn(*mu.size())) if input.is_cuda: eps = eps.cuda() return std * eps + mu + self.bias, kld def max_alpha(self): log_alpha = self.log_sigma2 - self.theta**2 return torch.max(log_alpha.exp())
class VariationalDropout(nn.Module): def __init__(self, input_size, out_size, log_sigma2=-10, threshold=3): """ This module creates a fully connected layer with variational dropout enabled :param input_size: An int of input size :param log_sigma2: Initial value of log sigma ^ 2. It is crucial for training since it determines the initial value of alpha :param threshold: Threshold used at validation. If log_alpha > threshold, the weight is zeroed :param out_size: An int of output size """ super(VariationalDropout, self).__init__() self.input_size = input_size self.out_size = out_size self.theta = Parameter(t.FloatTensor( input_size, out_size)) # fully connected weight self.bias = Parameter(t.Tensor(out_size)) # bias self.log_sigma2 = Parameter( t.FloatTensor(input_size, out_size).fill_( log_sigma2)) # log-variance of the Gaussian noise applied i.i.d. to each weight self.reset_parameters() self.k = [0.63576, 1.87320, 1.48695] self.threshold = threshold # as noted above, this is used to zero a weight when its Gaussian noise ball has too large a radius. def reset_parameters(self): stdv = 1. / math.sqrt(self.out_size) self.theta.data.uniform_(-stdv, stdv) self.bias.data.uniform_(-stdv, stdv) @staticmethod def clip(input, to=8): input = input.masked_fill(input < -to, -to) input = input.masked_fill(input > to, to) return input def kld(self, log_alpha): first_term = self.k[0] * F.sigmoid(self.k[1] + self.k[2] * log_alpha) second_term = 0.5 * t.log(1 + t.exp(-log_alpha)) return -(first_term - second_term - self.k[0]).sum() / ( self.input_size * self.out_size) def forward(self, input): """ :param input: A float tensor with shape of [batch_size, input_size] :return: A float tensor with shape of [batch_size, out_size] and negative layer-kld estimation """ log_alpha = self.clip(self.log_sigma2 - t.log(self.theta**2)) kld = self.kld(log_alpha) if not self.training: mask = log_alpha > self.threshold return t.addmm(self.bias, input, self.theta.masked_fill(mask, 0)) mu = t.mm(input, self.theta) std = t.sqrt(t.mm(input**2, self.log_sigma2.exp()) + 1e-6) eps = Variable( t.randn(*mu.size())) # sample from the standard normal distribution if input.is_cuda: eps = eps.cuda() return std * eps + mu + self.bias, kld # the reparameterization trick that forms the Gaussian dropout def max_alpha(self): log_alpha = self.log_sigma2 - self.theta**2 return t.max(log_alpha.exp())
class VariationalDropout(nn.Module): def __init__(self, input_size, out_size, log_sigma2=-10, threshold=3): """ :param input_size: An int of input size :param log_sigma2: Initial value of log sigma ^ 2. It is crucial for training since it determines the initial value of alpha :param threshold: Threshold used at validation. If log_alpha > threshold, the weight is zeroed :param out_size: An int of output size """ super(VariationalDropout, self).__init__() self.input_size = input_size self.out_size = out_size self.theta = Parameter(t.FloatTensor(input_size, out_size)) self.bias = Parameter(t.Tensor(out_size)) self.log_sigma2 = Parameter( t.FloatTensor(input_size, out_size).fill_(log_sigma2)) self.reset_parameters() self.k = [0.63576, 1.87320, 1.48695] self.threshold = threshold def reset_parameters(self): stdv = 1. / math.sqrt(self.out_size) self.theta.data.uniform_(-stdv, stdv) self.bias.data.uniform_(-stdv, stdv) @staticmethod def clip(input, to=8.): input = input.masked_fill(input < -to, -to) input = input.masked_fill(input > to, to) return input def kld(self, log_alpha): first_term = self.k[0] * F.sigmoid(self.k[1] + self.k[2] * log_alpha) second_term = 0.5 * t.log(1 + t.exp(-log_alpha)) return (first_term - second_term - self.k[0]).sum() / (self.input_size * self.out_size) def forward(self, input, train): """ :param input: A float tensor with shape of [batch_size, input_size] :return: A float tensor with shape of [batch_size, out_size] and negative layer-kld estimation """ log_alpha = self.clip(self.log_sigma2 - t.log(self.theta**2)) fh = open("log_alpha_values_during_training.txt", 'a') fh.write( str(self.input_size) + "||||" + str(log_alpha.data.numpy()[0][0]) + "\n") fh.close() #print(log_alpha.data.numpy()[0][0]) kld = self.kld(log_alpha) if not train: mask = log_alpha > self.threshold if (t.nonzero(mask).dim() != 0): zeroed_weights = t.nonzero(mask).size(0) else: zeroed_weights = 0 total_weights = mask.size(0) * mask.size(1) print('number of zeroed weights is {}'.format(zeroed_weights)) print('total number of weights is {}'.format(total_weights)) print('ratio of non-zeroed weights is {}'.format( (total_weights - zeroed_weights) / total_weights)) return t.addmm(self.bias, input, self.theta.masked_fill(mask, 0)) mu = t.mm(input, self.theta) std = t.sqrt(t.mm(input**2, self.log_sigma2.exp()) + 1e-6) eps = Variable(t.randn(*mu.size())) if input.is_cuda: eps = eps.cuda() return std * eps + mu + self.bias, kld def max_alpha(self): log_alpha = self.log_sigma2 - self.theta**2 return t.max(log_alpha)
class HSConv2d(Module): '''Input channel noise''' def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, prior_std=1., prior_std_z=1., dof=1., **kwargs): super(HSConv2d, self).__init__() if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: raise ValueError('out_channels must be divisible by groups') self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.prior_std = prior_std self.prior_std_z = prior_std_z self.use_bias = False self.dof = dof self.mean_w = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.logvar_w = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.qz_mean = Parameter(torch.Tensor(in_channels // groups)) self.qz_logvar = Parameter(torch.Tensor(in_channels // groups)) self.dim_z = in_channels // groups if bias: self.mean_bias = Parameter(torch.Tensor(out_channels)) self.logvar_bias = Parameter(torch.Tensor(out_channels)) self.use_bias = True self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.mean_w, mode='fan_in') self.logvar_w.data.normal_(-9., 1e-4) self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3) self.qz_logvar.data.normal_(math.log(0.1), 1e-4) if self.use_bias: self.mean_bias.data.normal_(0, 1e-2) self.logvar_bias.data.normal_(-9., 1e-4) def constrain_parameters(self, thres_std=1.): self.logvar_w.data.clamp_(max=2. * math.log(thres_std)) if self.use_bias: self.logvar_bias.data.clamp_(max=2. * math.log(thres_std)) def eq_logpw(self): logpw = -.5 * math.log( 2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div( self.prior_std**2) logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2) logpb = 0. if self.use_bias: logpb = -.5 * math.log(2 * math.pi * self.prior_std** 2) - .5 * self.logvar_bias.exp().div( self.prior_std**2) logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2) logpb = torch.sum(logpb) return torch.sum(logpw) + logpb def eq_logqw(self): logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1)) logqb = 0. if self.use_bias: logqb = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_bias + 1)) return logqw + logqb def kldiv_aux(self): z = self.sample_z(1) z = z.view(self.dim_z) logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1)) logqm = logqm.add(-torch.sum(F.sigmoid(z.exp().add(-1).log()).log())) logpm = torch.sum( 2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof) - math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi) - .5 * (self.dof + 1) * torch.log(1. 
+ z.pow(2) / (self.dof * self.prior_std_z**2))) return logpm - logqm def kldiv(self): return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw() def get_eps(self, size): eps = self.floatTensor(size).normal_() eps = Variable(eps) return eps def sample_z(self, batch_size): z = self.qz_mean.view(1, self.dim_z).expand(batch_size, self.dim_z) if self.training: eps = self.get_eps(self.floatTensor(batch_size, self.dim_z)) z = z.add( eps.mul( self.qz_logvar.view(1, self.dim_z).expand( batch_size, self.dim_z).mul(0.5).exp_())) z = z.contiguous().view(batch_size, self.dim_z, 1, 1) return F.softplus(z) def sample_W(self): W = self.mean_w if self.training: eps = self.get_eps(self.mean_w.size()) W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_())) return W def sample_b(self): if not self.use_bias: return None b = self.mean_bias if self.training: eps = self.get_eps(self.mean_bias.size()) b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_())) return b def forward(self, input_): z = self.sample_z(input_.size(0)) W = self.sample_W() b = self.sample_b() return F.conv2d(input_.mul(z.expand_as(input_)), W, b, self.stride, self.padding, self.dilation, self.groups) def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}, prior_std_z={prior_std_z}, dof={dof}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if not self.use_bias: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
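The HSConv2d above calls a pair helper that is not shown; a minimal sketch, under the assumption that it mirrors torch.nn.modules.utils._pair:

# Assumed helper: duplicate a scalar into a 2-tuple, pass sequences through,
# mirroring torch.nn.modules.utils._pair.
def pair(x):
    if isinstance(x, (tuple, list)):
        return tuple(x)
    return (x, x)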
class GHConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True): # Init torch module super(GHConv2d, self).__init__() # Init conv params self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) self.stride = stride self.padding = padding self.dilation = dilation # init constants according to section 5 self.t0 = 1e-5 # Init globals self.sa_mu = Parameter(Tensor(1)) self.sa_logvar = Parameter(Tensor(1)) self.sb_mu = Parameter(Tensor(1)) self.sb_logvar = Parameter(Tensor(1)) # Filter locals self.alpha_mu = Parameter(Tensor(out_channels)) self.alpha_logvar = Parameter(Tensor(out_channels)) self.beta_mu = Parameter(Tensor(out_channels)) self.beta_logvar = Parameter(Tensor(out_channels)) # Weight local self.weight_mu = Parameter(Tensor(out_channels, in_channels, *self.kernel_size)) self.weight_logvar = Parameter(Tensor(out_channels, in_channels, *self.kernel_size)) # Bias local if required self.bias = bias self.bias_mu = Parameter(Tensor(out_channels)) if self.bias else None self.bias_logvar = Parameter(Tensor(out_channels)) if self.bias else None # Set initial parameters self._init_params() # for brevity to conv2d calls self.convargs = [self.stride, self.padding, self.dilation] def _s_mu(self): return 0.5 * (self.sa_mu + self.sb_mu) def _s_var(self): return 0.25 * (self.sa_logvar.exp() + self.sb_logvar.exp()) def _z_var(self): return 0.25 * (self.alpha_logvar.exp() + self.beta_logvar.exp()) def _z_mu(self): return 0.5 * (self.alpha_mu + self.beta_mu) def forward(self, x): # vanilla forward pass if testing if not self.training: expect_z = torch.exp(0.5 * (self._z_var() + self._s_var()) + self._z_mu() + self._s_mu()) post_weight_mu = self.weight_mu * expect_z[:, None, None, None] post_bias_mu = self.bias_mu * expect_z if (self.bias_mu is not None) else None return conv2d(x, post_weight_mu, post_bias_mu, *self.convargs) # compute global shrinkage s_mu = 0.5 * (self.sa_mu + self.sb_mu) s_sig = torch.sqrt(self._s_var()) s = LogNormal(s_mu, s_sig).rsample() # compute filter scales z_mu = self._z_mu() z_var = self._z_var() z = s * LogNormal(z_mu, z_var.sqrt()).rsample()[None, :, None, None] # lognormal out params, local reparameterization trick bvar = self.bias_logvar.exp() if self.bias else None mu_out = conv2d(x, self.weight_mu, self.bias_mu, *self.convargs) * z scale_out = conv2d(x**2, self.weight_logvar.exp(), bvar, *self.convargs) * (z ** 2) # compute output weight distribution, again reparameterised dist_out = Normal(mu_out, scale_out.sqrt()).rsample() # return fully reparameterised forward pass return dist_out def _init_params(self, weight=None, bias=None): # initialisation params - note mean of lognormal is exp(mu + 0.5 *var) init_mu_logvar, init_mu, init_var = -9, 0., 1e-2 # compute xavier initialisation on weights n = self.in_channels * self.kernel_size[0] * self.kernel_size[1] thresh = 1/math.sqrt(n) if weight is not None: self.weight_mu.data = weight else: self.weight_mu.data.uniform_(-thresh, thresh) # init variance according to appendix A self.weight_logvar.data.normal_(init_mu_logvar, init_var) if self.bias: if bias is not None: self.bias_mu.data = bias else: self.bias_mu.data.fill_(0) # biases self.bias_logvar.data.normal_(init_mu_logvar, init_var) # Decomposed prior means => E[z_init] init ~ 1 self.alpha_mu.data.normal_(init_mu, init_var) self.beta_mu.data.normal_(init_mu, init_var) self.sa_mu.data.normal_(init_mu, init_var) self.sb_mu.data.normal_(init_mu, init_var) # 
Decomposed prior variances self.alpha_logvar.data.normal_(init_mu_logvar, init_var) self.beta_logvar.data.normal_(init_mu_logvar, init_var) self.sa_logvar.data.normal_(init_mu_logvar, init_var) self.sb_logvar.data.normal_(init_mu_logvar, init_var) # KL div for GNH with lognormal scale, normal weight variational posterior def kl_divergence(self): # negative kls, eqns (34-37) neg_kl_s = self._global_negative_kl() neg_kl_ab = self._filter_local_negative_kl() # weight/bias local kl_w = self._conditional_kl_div(self.weight_mu, self.weight_logvar) if self.bias: kl_b = self._conditional_kl_div(self.bias_mu, self.bias_logvar) else: kl_b = 0 return kl_w + kl_b - (neg_kl_s + neg_kl_ab) def _global_negative_kl(self): # hyperparams t0 = self.t0 # const added in every kl div c = 1 + math.log(2) # shape/scale of global scale parameters sa_mu, sb_mu = self.sa_mu, self.sb_mu sa_var, sb_var = self.sa_logvar.exp(), self.sb_logvar.exp() # Eqns (34)(35) kl_sa = math.log(t0) - torch.exp(sa_mu + 0.5 * sa_var)/t0 + 0.5 * (sa_mu + self.sa_logvar + c) kl_sb = 0.5 * (self.sb_logvar - sb_mu + c ) - torch.exp(0.5 * sb_var - sb_mu) return kl_sa + kl_sb def _filter_local_negative_kl(self): # const added in every kl div c = 1 + math.log(2) # hyperparams t0 = self.t0 # filter level shape/scale parameters alpha_mu, beta_mu = self.alpha_mu, self.beta_mu alpha_logvar, beta_logvar = self.alpha_logvar, self.beta_logvar # Eqns (36)(37) kl_alpha = torch.sum(0.5 * (alpha_mu + alpha_logvar + c) - torch.exp(alpha_mu + 0.5 * alpha_logvar.exp())) kl_beta = torch.sum(0.5 * (beta_logvar - beta_mu + c) - torch.exp(0.5 * beta_logvar.exp() - beta_mu)) return kl_alpha + kl_beta @staticmethod def _conditional_kl_div(mu, logvar): # eqn (8) kl_div = -0.5 * logvar + 0.5 * (logvar.exp() + mu ** 2 - 1) return torch.sum(kl_div)
class VDropCentralData(nn.Module): """ Stores data for a set of variational dropout (VDrop) modules in large central tensors. The VDrop modules access the data using views. This makes it possible to operate on all of the data at once, (rather than e.g. 53 times with resnet50). Usage: 1. Instantiate 2. Pass into multiple constructed VDropLinear and VDropConv2d modules 3. Call finalize Before calling forward on the model, call "compute_forward_data". After calling forward on the model, call "clear_forward_data". The parameters are stored in terms of z_mu and z_var rather than w_mu and w_var to support group variational dropout (e.g. to allow for pruning entire channels.) """ def __init__(self, z_logvar_init=-10): super().__init__() self.z_chunk_sizes = [] self.z_logvar_init = z_logvar_init self.z_logvar_min = min(z_logvar_init, -10) self.z_logvar_max = 10. self.epsilon = 1e-8 self.data_views = {} self.modules = [] # Populated during register(), deleted during finalize() self.all_z_mu = [] self.all_z_logvar = [] self.all_num_weights = [] # Populated during finalize() self.z_mu = None self.z_logvar = None self.z_num_weights = None self.threshold = 3 def extra_repr(self): s = f"z_logvar_init={self.z_logvar_init}" return s def __getitem__(self, key): return self.data_views[key] def register(self, module, z_mu, z_logvar, num_weights_per_z=1): self.all_z_mu.append(z_mu.flatten()) self.all_z_logvar.append(z_logvar.flatten()) self.all_num_weights.append(num_weights_per_z) self.modules.append(module) data_index = len(self.z_chunk_sizes) self.z_chunk_sizes.append(z_mu.numel()) return data_index def finalize(self): self.z_mu = Parameter(torch.cat(self.all_z_mu)) self.z_logvar = Parameter(torch.cat(self.all_z_logvar)) self.z_num_weights = torch.tensor(self.all_num_weights, dtype=torch.float).repeat_interleave( torch.tensor(self.z_chunk_sizes)) del self.all_z_mu del self.all_z_logvar del self.all_num_weights def to(self, *args, **kwargs): ret = super().to(*args, **kwargs) self.z_num_weights = self.z_num_weights.to(*args, **kwargs) return ret def compute_forward_data(self): if self.training: self.data_views["z_mu"] = self.z_mu.split(self.z_chunk_sizes) self.data_views["z_var"] = self.z_logvar.exp().split( self.z_chunk_sizes) else: self.data_views["z_mu"] = ( self.z_mu * (self.compute_z_logalpha() < self.threshold).float()).split( self.z_chunk_sizes) def clear_forward_data(self): self.data_views.clear() def compute_z_logalpha(self): return self.z_logvar - (self.z_mu.square() + self.epsilon).log() def regularization(self): return (vdrop_regularization(self.compute_z_logalpha()) * self.z_num_weights).sum() def constrain_parameters(self): self.z_logvar.data.clamp_(min=self.z_logvar_min, max=self.z_logvar_max)
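A hedged sketch of the lifecycle described in the VDropCentralData docstring. The VDropLinear constructor signature taking a central_data argument is an assumption about the companion modules (which are expected to call register() internally); the layer sizes are illustrative.

import torch

central_data = VDropCentralData()
# Companion modules are assumed to call central_data.register(...) in __init__.
model = torch.nn.Sequential(
    VDropLinear(784, 300, central_data=central_data), torch.nn.ReLU(),
    VDropLinear(300, 10, central_data=central_data))
central_data.finalize()  # concatenates all registered z_mu / z_logvar chunks

def forward_with_vdrop(x):
    central_data.compute_forward_data()  # build the per-module views
    out = model(x)
    central_data.clear_forward_data()    # release the views
    return out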
class _ConvNdGroupNJ(Module): """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout). References: [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, init_weight, init_bias, cuda=False, clip_var=None): super(_ConvNdGroupNJ, self).__init__() if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: raise ValueError('out_channels must be divisible by groups') self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.dilation = dilation self.transposed = transposed self.output_padding = output_padding self.groups = groups self.cuda = cuda self.clip_var = clip_var self.deterministic = False # flag is used for compressed inference if transposed: self.weight_mu = Parameter(torch.Tensor( in_channels, out_channels // groups, *kernel_size)) self.weight_logvar = Parameter(torch.Tensor( in_channels, out_channels // groups, *kernel_size)) else: self.weight_mu = Parameter(torch.Tensor( out_channels, in_channels // groups, *kernel_size)) self.weight_logvar = Parameter(torch.Tensor( out_channels, in_channels // groups, *kernel_size)) self.bias_mu = Parameter(torch.Tensor(out_channels)) self.bias_logvar = Parameter(torch.Tensor(out_channels)) self.z_mu = Parameter(torch.Tensor(self.out_channels)) self.z_logvar = Parameter(torch.Tensor(self.out_channels)) self.reset_parameters(init_weight, init_bias) # activations for kl self.sigmoid = nn.Sigmoid() self.softplus = nn.Softplus() # numerical stability param self.epsilon = 1e-8 def reset_parameters(self, init_weight, init_bias): # init means n = self.in_channels for k in self.kernel_size: n *= k stdv = 1. 
/ math.sqrt(n) # init means if init_weight is not None: self.weight_mu.data = init_weight else: self.weight_mu.data.uniform_(-stdv, stdv) if init_bias is not None: self.bias_mu.data = init_bias else: self.bias_mu.data.fill_(0) # inti z self.z_mu.data.normal_(1, 1e-2) # init logvars self.z_logvar.data.normal_(-9, 1e-2) self.weight_logvar.data.normal_(-9, 1e-2) self.bias_logvar.data.normal_(-9, 1e-2) def clip_variances(self): if self.clip_var: self.weight_logvar.data.clamp_(max=math.log(self.clip_var)) self.bias_logvar.data.clamp_(max=math.log(self.clip_var)) def get_log_dropout_rates(self): log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) return log_alpha def compute_posterior_params(self): weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var self.post_weight_mu = self.weight_mu * self.z_mu return self.post_weight_mu, self.post_weight_var def kl_divergence(self): # KL(q(z)||p(z)) # we use the kl divergence approximation given by [2] Eq.(14) k1, k2, k3 = 0.63576, 1.87320, 1.48695 log_alpha = self.get_log_dropout_rates() KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) # KL(q(w|z)||p(w|z)) # we use the kl divergence given by [3] Eq.(8) KLD_element = - 0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) # KL bias KLD_element = - 0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) return KLD def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}') if self.padding != (0,) * len(self.padding): s += ', padding={padding}' if self.dilation != (1,) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0,) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if self.bias is None: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
class TrimConv2d(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, log_alpha=-10.0, lamda=0.1, h=0): super(TrimConv2d, self).__init__() if in_channels % groups != 0: raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: raise ValueError("out_channels must be divisible by groups") self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.dilation = dilation self.groups = groups self.weight = Parameter( torch.Tensor(out_channels, in_channels // groups, *kernel_size)) if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) # For Bayesian inference self.log_alpha = Parameter( torch.randn(*self.weight.size()).fill_(log_alpha)) # KL divergence self.c1 = 1.16145124 self.c2 = -1.50204118 self.c3 = 0.58629921 # Trimming Parameters self.lamda = lamda # regularization parameters self.n_pnt = out_channels - h # the number of penalties if torch.cuda.is_available(): self.regw = torch.ones( out_channels).cuda() # output feature map sparsity else: self.regw = torch.ones(out_channels) self.mask = None self.bias_mask = None self.reset_parameters() def reset_parameters(self): nn.init.kaiming_normal_(self.weight, mode='fan_in') if self.bias is not None: self.bias.data.fill_(0.0) self.regw.data.fill_(self.n_pnt / self.out_channels) self.regw.requires_grad_() def forward(self, inputs): self.log_alpha.data.clamp_(max=0.0) if self.training: # Local Reparametrization Trick on Training Phase mu = F.conv2d(inputs, self.weight) std = torch.sqrt( F.conv2d(inputs**2, self.log_alpha.exp() * self.weight**2) + 1e-8) # This means that sampling only one time for each datapoint eps = torch.randn(*mu.size()) if inputs.is_cuda: eps = eps.cuda() output_size = eps.size() conv_bias = self.bias.unsqueeze(0).repeat( output_size[0], 1).unsqueeze(-1).repeat(1, 1, output_size[2]).unsqueeze(-1).repeat( 1, 1, 1, output_size[3]) return torch.add((mu + std * eps), conv_bias) else: # Test Phase if self.mask is not None: return F.conv2d(inputs, self.mask * self.weight, self.bias) else: return F.conv2d(inputs, self.weight, self.bias) def kld(self): """ Variational Dropout and the Local Reparametrization Trick --> Variational (A2) method """ self.log_alpha.data.clamp_(max=0.0) alpha = self.log_alpha.exp() nkld = 0.5 * self.log_alpha + self.c1 * alpha + self.c2 * alpha**2 + self.c3 * alpha**3 kld = -nkld return kld.mean() / 3 def compute_expected_flops(self): """ To be implemented """ return def reg_theta_loss(self): """ For subgradient method """ eps = torch.randn(*self.weight.size()) if torch.cuda.is_available(): eps = eps.cuda() mu = self.weight std = torch.sqrt(self.log_alpha.exp() * self.weight**2 + 1e-8) samples = mu + std * eps return self.lamda * torch.sum(self.regw.data * conv_norm(samples)) def reg_w_loss(self): eps = torch.randn(*self.weight.size()) if torch.cuda.is_available(): eps = eps.cuda() mu = self.weight.data std = torch.sqrt(self.log_alpha.data.exp() * self.weight.data**2 + 1e-8) samples = mu + std * eps return torch.sum(self.regw * conv_norm(samples)) def set_weight_mask(self): _, in_filters, height, width = self.weight.size() self.mask = self.regw.data < 1.0 if torch.cuda.is_available(): self.mask = self.mask.type(torch.cuda.FloatTensor) else: self.mask = self.mask.type(torch.FloatTensor) self.mask = self.mask.unsqueeze(-1).repeat( 1, in_filters).unsqueeze(-1).repeat(1, 1, 
height).unsqueeze(-1).repeat( 1, 1, 1, width) def set_bias_mask(self): self.bias_mask = self.regw.data < 1.0 if torch.cuda.is_available(): self.bias_mask = self.bias_mask.type(torch.cuda.FloatTensor) else: self.bias_mask = self.bias_mask.type(torch.FloatTensor) return def apply_mask(self): self.weight.data = self.weight.data * self.mask if self.bias_mask is not None: self.bias.data = self.bias.data * self.bias_mask return def extra_repr(self): return "in_channels={}, out_channels={}, bias={}".format( self.in_channels, self.out_channels, self.bias is not None)
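TrimConv2d's group regularizers call a conv_norm helper that is not included here. A minimal sketch, under the assumption that it returns one L2 norm per output filter so it can be weighted elementwise by regw:

# Assumed helper: per-output-filter L2 norm of a conv weight tensor of shape
# (out_channels, in_channels // groups, kh, kw), returning (out_channels,).
def conv_norm(weight):
    return weight.view(weight.size(0), -1).norm(dim=1)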
class TrimDense(Module): """ Dense layer for the Trimmed \ell_1 Regularization. We treat the network parameters in a Bayesian manner """ def __init__(self, in_features, out_features, log_alpha='hidden', lamda=0.1, h=0, bias=True): """ Args: in_features: the number of input-neurons out_features: the number of output-neurons h: the number of largest entries which are not penalized bias: use bias or not """ super(TrimDense, self).__init__() assert in_features >= h # Fully-Connected Layers self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(in_features, out_features)) if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) # For Bayesian Inference self.log_alpha = Parameter( torch.randn(in_features, out_features).fill_(-10.0)) self.c1 = 1.16145124 self.c2 = -1.50204118 self.c3 = 0.58629921 # Trimming Parameters self.lamda = lamda # regularization parameters self.n_pnt = in_features - h # the number of penalties if torch.cuda.is_available(): self.regw = torch.ones(in_features).cuda() # input-neuron sparsity else: self.regw = torch.ones(in_features) self.mask = None self.bias_mask = None self.reset_parameters() def reset_parameters(self): nn.init.kaiming_normal_(self.weight.data, mode='fan_out') if self.bias is not None: self.bias.data.fill_(0) self.regw.data.fill_(self.n_pnt / self.in_features) self.regw.requires_grad_() def forward(self, inputs): self.log_alpha.data.clamp_(max=0.0) if self.training: # Local Reparametrization Trick on Training Phase mu = torch.mm(inputs, self.weight) std = torch.sqrt( torch.mm(inputs**2, self.log_alpha.exp() * self.weight**2) + 1e-8) # i.e., the noise is sampled only once for each datapoint eps = torch.randn(*mu.size()) if inputs.is_cuda: eps = eps.cuda() return std * eps + mu + self.bias else: # Test Phase if self.mask is not None: return torch.addmm(self.bias, inputs, self.mask * self.weight) else: return torch.addmm(self.bias, inputs, self.weight) def kld(self): """ Variational Dropout and the Local Reparametrization Trick --> Variational (A2) method """ self.log_alpha.data.clamp_(max=0.0) alpha = self.log_alpha.exp() nkld = 0.5 * self.log_alpha + self.c1 * alpha + self.c2 * alpha**2 + self.c3 * alpha**3 kld = -nkld return kld.mean() / 3 def compute_expected_flops(self): """ To be implemented """ return def reg_theta_loss(self, batch_size=100): """ For subgradient method """ eps = torch.randn(*self.weight.size()) if torch.cuda.is_available(): eps = eps.cuda() mu = self.weight std = torch.sqrt(self.log_alpha.exp() * self.weight**2 + 1e-8) samples = mu + std * eps return self.lamda * torch.sum(self.regw.data * samples.norm(dim=1)) def reg_w_loss(self, batch_size=100): eps = torch.randn(*self.weight.size()) if torch.cuda.is_available(): eps = eps.cuda() mu = self.weight.data std = torch.sqrt(self.log_alpha.data.exp() * self.weight.data**2 + 1e-8) samples = mu + std * eps return torch.sum(self.regw * samples.norm(dim=1)) def set_weight_mask(self): self.mask = (self.regw.data < 1.0).unsqueeze(-1).repeat( 1, self.out_features) if torch.cuda.is_available(): self.mask = self.mask.type(torch.cuda.FloatTensor) else: self.mask = self.mask.type(torch.FloatTensor) def set_bias_mask(self, next_layer_regw): assert len(next_layer_regw) == self.out_features self.bias_mask = next_layer_regw.data < 1.0 if torch.cuda.is_available(): self.bias_mask = self.bias_mask.type(torch.cuda.FloatTensor) else: self.bias_mask = self.bias_mask.type(torch.FloatTensor) return def apply_mask(self): self.weight.data = self.weight.data * self.mask if self.bias_mask is not None: self.bias.data = self.bias.data * self.bias_mask return def extra_repr(self): return "in_features={}, out_features={}, bias={}".format( self.in_features, self.out_features, self.bias is not None)
class LinearGroupNJ(BayesianLayers): """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout). References: [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). """ def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None): super(LinearGroupNJ, self).__init__() self.cuda = cuda self.in_features = in_features self.out_features = out_features self.clip_var = clip_var self.deterministic = False # flag is used for compressed inference # trainable params according to Eq.(6) # dropout params self.z_mu = Parameter(torch.Tensor(in_features)) self.z_logvar = Parameter(torch.Tensor(in_features)) # = z_mu^2 * alpha # weight params self.weight_mu = Parameter(torch.Tensor(out_features, in_features)) self.weight_logvar = Parameter(torch.Tensor(out_features, in_features)) self.bias_mu = Parameter(torch.Tensor(out_features)) self.bias_logvar = Parameter(torch.Tensor(out_features)) # init params either random or with pretrained net self.reset_parameters(init_weight, init_bias) # activations for kl self.sigmoid = nn.Sigmoid() self.softplus = nn.Softplus() # numerical stability param self.epsilon = 1e-8 def reset_parameters(self, init_weight, init_bias): # init means stdv = 1. / math.sqrt(self.weight_mu.size(1)) self.z_mu.data.normal_(1, 1e-2) if init_weight is not None: self.weight_mu.data = torch.Tensor(init_weight) else: self.weight_mu.data.normal_(0, stdv) if init_bias is not None: self.bias_mu.data = torch.Tensor(init_bias) else: self.bias_mu.data.fill_(0) # init logvars self.z_logvar.data.normal_(-9, 1e-2) self.weight_logvar.data.normal_(-9, 1e-2) self.bias_logvar.data.normal_(-9, 1e-2) def clip_variances(self): if self.clip_var: self.weight_logvar.data.clamp_(max=math.log(self.clip_var)) self.bias_logvar.data.clamp_(max=math.log(self.clip_var)) def get_log_dropout_rates(self): log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) return log_alpha def compute_posterior_params(self): weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var self.post_weight_mu = self.weight_mu * self.z_mu # print("self.z_mu.pow(2): ", self.z_mu.pow(2).size()) # print("weight_var: ", weight_var.size()) # print("z_var: ", z_var.size()) # print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size()) # print("weight_var: ", weight_var.size()) # print("post_weight_mu: ", self.post_weight_mu.size()) # print("post_weight_var: ", self.post_weight_var.size()) return self.post_weight_mu, self.post_weight_var def forward(self, x): if self.deterministic: assert self.training == False, "Flag deterministic is True. This should not be used in training." return F.linear(x, self.post_weight_mu, self.bias_mu) batch_size = x.size()[0] # compute z # note that we reparametrise according to [2] Eq. (11) (not [1]) z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training, cuda=self.cuda) # apply local reparametrisation trick see [1] Eq. (6) # to the parametrisation given in [3] Eq. 
(6) xz = x * z mu_activations = F.linear(xz, self.weight_mu, self.bias_mu) var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp()) return reparametrize(mu_activations, var_activations.log(), sampling=self.training, cuda=self.cuda) def kl_divergence(self): # KL(q(z)||p(z)) # we use the kl divergence approximation given by [2] Eq.(14) k1, k2, k3 = 0.63576, 1.87320, 1.48695 log_alpha = self.get_log_dropout_rates() KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) # KL(q(w|z)||p(w|z)) # we use the kl divergence given by [3] Eq.(8) KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) # KL bias KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) return KLD def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')'
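LinearGroupNJ (and the convolutional GroupNJ layers) rely on a reparametrize helper that is not shown. A minimal sketch, assuming the standard Gaussian reparameterization with optional sampling:

import torch

# Assumed helper: z = mu + std * eps with eps ~ N(0, I); returns the mean when
# sampling is disabled (e.g. at evaluation time).
def reparametrize(mu, logvar, cuda=False, sampling=True):
    if not sampling:
        return mu
    std = logvar.mul(0.5).exp()
    eps = (torch.cuda.FloatTensor(std.size()) if cuda
           else torch.FloatTensor(std.size())).normal_()
    return mu + eps * std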
class VaDE(torch.nn.Module): """Variational Deep Embedding(VaDE). Args: n_classes (int): Number of clusters. data_dim (int): Dimension of observed data. latent_dim (int): Dimension of latent space. """ def __init__(self, n_classes, data_dim, latent_dim): super(VaDE, self).__init__() self._pi = Parameter(torch.zeros(n_classes)) self.mu = Parameter(torch.randn(n_classes, latent_dim)) self.logvar = Parameter(torch.randn(n_classes, latent_dim)) self.encoder = torch.nn.Sequential( torch.nn.Linear(data_dim, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 2048), torch.nn.ReLU(), ) self.encoder_mu = torch.nn.Linear(2048, latent_dim) self.encoder_logvar = torch.nn.Linear(2048, latent_dim) self.decoder = torch.nn.Sequential( torch.nn.Linear(latent_dim, 2048), torch.nn.ReLU(), torch.nn.Linear(2048, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, data_dim), torch.nn.Sigmoid(), ) @property def weights(self): return torch.softmax(self._pi, dim=0) def encode(self, x): h = self.encoder(x) mu = self.encoder_mu(h) logvar = self.encoder_logvar(h) return mu, logvar def decode(self, z): return self.decoder(z) def forward(self, x): mu, logvar = self.encode(x) z = _reparameterize(mu, logvar) recon_x = self.decode(z) return recon_x, mu, logvar def classify(self, x, n_samples=8): with torch.no_grad(): mu, logvar = self.encode(x) z = torch.stack( [_reparameterize(mu, logvar) for _ in range(n_samples)], dim=1) z = z.unsqueeze(2) h = z - self.mu h = torch.exp(-0.5 * torch.sum(h * h / self.logvar.exp(), dim=3)) # Same as `torch.sqrt(torch.prod(self.logvar.exp(), dim=1))` h = h / torch.sum(0.5 * self.logvar, dim=1).exp() p_z_given_c = h / (2 * math.pi) p_z_c = p_z_given_c * self.weights y = p_z_c / torch.sum(p_z_c, dim=2, keepdim=True) y = torch.sum(y, dim=1) pred = torch.argmax(y, dim=1) return pred
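VaDE calls a _reparameterize helper that is not shown; a minimal sketch assuming the standard VAE reparameterization trick:

import torch

# Assumed helper: standard VAE reparameterization, z = mu + sigma * eps.
def _reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std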
class _ConvNdGroupNJ(BayesianLayers): """Convolutional Group Normal-Jeffrey's layers (aka Group Variational Dropout). References: [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, init_weight, init_bias, cuda=False, clip_var=None): super(_ConvNdGroupNJ, self).__init__() if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: raise ValueError('out_channels must be divisible by groups') self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.dilation = dilation self.transposed = transposed self.output_padding = output_padding self.groups = groups self.cuda = cuda self.clip_var = clip_var self.deterministic = False # flag is used for compressed inference if transposed: self.weight_mu = Parameter(torch.Tensor( in_channels, out_channels // groups, *kernel_size)) self.weight_logvar = Parameter(torch.Tensor( in_channels, out_channels // groups, *kernel_size)) else: self.weight_mu = Parameter(torch.Tensor( out_channels, in_channels // groups, *kernel_size)) self.weight_logvar = Parameter(torch.Tensor( out_channels, in_channels // groups, *kernel_size)) self.bias_mu = Parameter(torch.Tensor(out_channels)) self.bias_logvar = Parameter(torch.Tensor(out_channels)) self.z_mu = Parameter(torch.Tensor(self.out_channels)) self.z_logvar = Parameter(torch.Tensor(self.out_channels)) self.reset_parameters(init_weight, init_bias) # activations for kl self.sigmoid = nn.Sigmoid() self.softplus = nn.Softplus() # numerical stability param self.epsilon = 1e-8 def reset_parameters(self, init_weight, init_bias): # init means n = self.in_channels for k in self.kernel_size: n *= k stdv = 1. 
/ math.sqrt(n) # init means if init_weight is not None: self.weight_mu.data = init_weight else: self.weight_mu.data.uniform_(-stdv, stdv) if init_bias is not None: self.bias_mu.data = init_bias else: self.bias_mu.data.fill_(0) # inti z self.z_mu.data.normal_(1, 1e-2) # init logvars self.z_logvar.data.normal_(-9, 1e-2) self.weight_logvar.data.normal_(-9, 1e-2) self.bias_logvar.data.normal_(-9, 1e-2) def clip_variances(self): if self.clip_var: self.weight_logvar.data.clamp_(max=math.log(self.clip_var)) self.bias_logvar.data.clamp_(max=math.log(self.clip_var)) def get_log_dropout_rates(self): log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) return log_alpha def compute_posterior_params(self): weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() print("self.z_mu.pow(2): ", self.z_mu.pow(2).size()) print("weight_var: ", weight_var.size()) print("z_var: ", z_var.size()) print("self.weight_mu.pow(2): ", self.weight_mu.pow(2).size()) print("weight_var: ", weight_var.size()) part1 = self.z_mu.pow(2) * weight_var part2 = z_var * self.weight_mu.pow(2) part3 = z_var * weight_var self.post_weight_var = part1 + part2 + part3 self.post_weight_mu = self.weight_mu * self.z_mu print("post_weight_mu: ", self.post_weight_mu.size()) print("post_weight_var: ", self.post_weight_var.size()) return self.post_weight_mu, self.post_weight_var def kl_divergence(self): # KL(q(z)||p(z)) # we use the kl divergence approximation given by [2] Eq.(14) k1, k2, k3 = 0.63576, 1.87320, 1.48695 log_alpha = self.get_log_dropout_rates() KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) # KL(q(w|z)||p(w|z)) # we use the kl divergence given by [3] Eq.(8) KLD_element = -self.weight_logvar + 0.5 * (self.weight_logvar.exp().pow(2) + self.weight_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) # KL bias KLD_element = -self.bias_logvar + 0.5 * (self.bias_logvar.exp().pow(2) + self.bias_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) return KLD def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}') if self.padding != (0,) * len(self.padding): s += ', padding={padding}' if self.dilation != (1,) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0,) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if self.bias is None: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
class HSDense(Module): def __init__(self, in_features, out_features, bias=True, prior_std=1., prior_std_z=1., dof=1., **kwargs): super(HSDense, self).__init__() self.in_features = in_features self.out_features = out_features self.prior_std = prior_std self.mean_w = Parameter(torch.Tensor(in_features, out_features)) self.logvar_w = Parameter(torch.Tensor(in_features, out_features)) self.qz_mean = Parameter(torch.Tensor(in_features)) self.qz_logvar = Parameter(torch.Tensor(in_features)) self.dof = dof self.prior_std_z = prior_std_z self.use_bias = False if bias: self.mean_bias = Parameter(torch.Tensor(out_features)) self.logvar_bias = Parameter(torch.Tensor(out_features)) self.use_bias = True self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.mean_w, mode='fan_out') self.logvar_w.data.normal_(-9., 1e-4) self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3) self.qz_logvar.data.normal_(math.log(0.1), 1e-4) if self.use_bias: self.mean_bias.data.normal_(0, 1e-2) self.logvar_bias.data.normal_(-9., 1e-4) def constrain_parameters(self, thres_std=1.): self.logvar_w.data.clamp_(max=2. * math.log(thres_std)) if self.use_bias: self.logvar_bias.data.clamp_(max=2. * math.log(thres_std)) def eq_logpw(self): logpw = -.5 * math.log( 2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div( self.prior_std**2) logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2) logpb = 0. if self.use_bias: logpb = - .5 * math.log(2 * math.pi * self.prior_std ** 2) - .5 * self.logvar_bias.exp().div \ (self.prior_std ** 2) logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2) return torch.sum(logpw) + torch.sum(logpb) def eq_logqw(self): logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1)) logqb = 0. if self.use_bias: logqb = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_bias + 1)) return logqw + logqb def kldiv_aux(self): z = self.sample_z(1) z = z.view(self.in_features) logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1)) logqm = logqm.add(-torch.sum(F.sigmoid(z.exp().add(-1).log()).log())) logpm = torch.sum( 2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof) - math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi) - .5 * (self.dof + 1) * torch.log(1. 
+ z.pow(2) / (self.dof * self.prior_std_z**2))) return logpm - logqm def kldiv(self): return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw() def get_eps(self, size): eps = self.floatTensor(size).normal_() eps = Variable(eps) return eps def sample_z(self, batch_size): z = self.qz_mean.view(1, self.in_features) if self.training: eps = self.get_eps(self.floatTensor(batch_size, self.in_features)) z = z + eps.mul( self.qz_logvar.view(1, self.in_features).mul(0.5).exp_()) return F.softplus(z) def sample_W(self): W = self.mean_w if self.training: eps = self.get_eps(self.mean_w.size()) W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_())) return W def sample_b(self): b = self.mean_bias if self.training: eps = self.get_eps(self.mean_bias.size()) b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_())) return b def get_mean_x(self, input): mean_xin = input.mm(self.mean_w) if self.use_bias: mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features)) return mean_xin def get_var_x(self, input): var_xin = input.pow(2).mm(self.logvar_w.exp()) if self.use_bias: var_xin = var_xin.add(self.logvar_bias.exp().view( 1, self.out_features)) return var_xin def forward(self, input): # sampling batch_size = input.size(0) z = self.sample_z(batch_size) xin = input.mul(z) mean_xin = self.get_mean_x(xin) output = mean_xin if self.training: var_xin = self.get_var_x(xin) eps = self.get_eps(self.floatTensor(batch_size, self.out_features)) output = output.add(var_xin.sqrt().mul(eps)) return output def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', dof: ' \ + str(self.dof) + ', prior_std_z: ' \ + str(self.prior_std_z) + ')'
class FFGaussDense(Module): def __init__(self, in_features, out_features, bias=True, prior_std=1., **kwargs): super(FFGaussDense, self).__init__() self.in_features = in_features self.out_features = out_features self.prior_std = prior_std self.mean_w = Parameter(torch.Tensor(in_features, out_features)) self.logvar_w = Parameter(torch.Tensor(in_features, out_features)) self.use_bias = False if bias: self.mean_bias = Parameter(torch.Tensor(out_features)) self.logvar_bias = Parameter(torch.Tensor(out_features)) self.use_bias = True self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.mean_w, mode='fan_out') self.logvar_w.data.normal_(-9., 1e-4) if self.use_bias: self.mean_bias.data.zero_() self.logvar_bias.data.normal_(-9., 1e-4) def constrain_parameters(self, thres_std=1.): self.logvar_w.data.clamp_(max=2. * math.log(thres_std)) if self.use_bias: self.logvar_bias.data.clamp_(max=2. * math.log(thres_std)) def eq_logpw(self): logpw = -.5 * math.log( 2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div( self.prior_std**2) logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2) logpb = 0. if self.use_bias: logpb = -.5 * math.log(2 * math.pi * self.prior_std** 2) - .5 * self.logvar_bias.exp().div( self.prior_std**2) logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2) return torch.sum(logpw) + torch.sum(logpb) def eq_logqw(self): logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1)) logqb = 0. if self.use_bias: logqb = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_bias + 1)) return logqw + logqb def kldiv_aux(self): return 0. def kldiv(self): return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw() def get_eps(self, size): eps = self.floatTensor(size).normal_() eps = Variable(eps) return eps def sample_pW(self): return self.floatTensor(self.in_features, self.out_features).normal_() def sample_pb(self): return self.floatTensor(self.out_features).normal_() def sample_W(self): w = self.mean_w if self.training: eps = self.get_eps(self.mean_w.size()) w = w.add(eps.mul(self.logvar_w.mul(0.5).exp_())) return w def sample_b(self): b = self.mean_bias if self.training: eps = self.get_eps(self.mean_bias.size()) b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_())) return b def get_mean_x(self, input): mean_xin = input.mm(self.mean_w) if self.use_bias: mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features)) return mean_xin def get_var_x(self, input): var_xin = input.pow(2).mm(self.logvar_w.exp()) if self.use_bias: var_xin = var_xin.add(self.logvar_bias.exp().view( 1, self.out_features)) return var_xin def forward(self, input): batch_size = input.size(0) mean_xin = self.get_mean_x(input) if self.training: var_xin = self.get_var_x(input) eps = self.get_eps(self.floatTensor(batch_size, self.out_features)) output = mean_xin.add(var_xin.sqrt().mul(eps)) else: output = mean_xin return output def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', prior_std: ' \ + str(self.prior_std) + ')'
class VDropLinear2(nn.Module):
    """
    A self-contained VDropLinear (doesn't use the VDropCentralData)
    """
    def __init__(self, in_features, out_features, bias=True,
                 w_logvar_init=-10):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.w_logvar_min = min(w_logvar_init, -10)
        self.w_logvar_max = 10.
        self.pruned_logvar_sentinel = self.w_logvar_max - 0.00058
        self.epsilon = 1e-8

        self.w_mu = Parameter(
            torch.Tensor(self.out_features, self.in_features))
        self.w_logvar = Parameter(
            torch.Tensor(self.out_features, self.in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.bias = None

        self.w_logvar.data.fill_(w_logvar_init)

        # Standard nn.Linear initialization.
        init.kaiming_uniform_(self.w_mu, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.w_mu)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

        self.tensor_constructor = (torch.FloatTensor
                                   if not torch.cuda.is_available()
                                   else torch.cuda.FloatTensor)

    def extra_repr(self):
        s = f"{self.in_features}, {self.out_features}"
        if self.bias is None:
            s += ", bias=False"
        return s

    def get_w_mu(self):
        return self.w_mu

    def get_w_var(self):
        return self.w_logvar.exp()

    def forward(self, x):
        if self.training:
            return vdrop_linear_forward(x, self.get_w_mu, self.get_w_var,
                                        self.bias, self.tensor_constructor)
        else:
            return F.linear(x, self.get_w_mu(), self.bias)

    def compute_w_logalpha(self):
        return self.w_logvar - (self.w_mu.square() + self.epsilon).log()

    def regularization(self):
        return vdrop_regularization(self.compute_w_logalpha()).sum()

    def constrain_parameters(self):
        self.w_logvar.data.clamp_(min=self.w_logvar_min,
                                  max=self.w_logvar_max)
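# --- Illustrative monitoring sketch (not part of the original sources) ---
# compute_w_logalpha() above gives log(alpha) = log(sigma^2) - log(mu^2), the
# log of the per-weight variational dropout rate. The helper below reports the
# implied keep probability 1 / (1 + alpha) = sigmoid(-log_alpha) and the
# fraction of weights that would be pruned; the 3.0 cut-off mirrors the
# threshold used by the other layers in this collection.
def effective_dropout_stats(layer, logalpha_threshold=3.0):
    log_alpha = layer.compute_w_logalpha()
    keep_prob = torch.sigmoid(-log_alpha)
    pruned_fraction = (log_alpha > logalpha_threshold).float().mean().item()
    return keep_prob, pruned_fraction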
class VariationalDropoutCNN(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, log_sigma2=-8, threshold=3):
        """
        :param in_channel: An int of input channels
        :param out_channel: An int of output channels
        :param log_sigma2: Initial value of log sigma ^ 2. It is crucial for
            training since it determines the initial value of alpha
        :param threshold: Value for thresholding at validation. If
            log_alpha > threshold, then the weight is zeroed
        """
        super(VariationalDropoutCNN, self).__init__()

        # self.m = img_row
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.theta = Parameter(
            t.Tensor(out_channel, in_channel // groups, kernel_size,
                     kernel_size))
        self.prior_theta = 0.
        self.prior_log_sigma2 = -2.
        # self.bias = Parameter(t.Tensor(out_channel, in_channel // groups, kernel_size, kernel_size))
        # self.bias = Parameter(t.Tensor(out_channel, self.m-kernel_size+1, self.m-kernel_size+1))
        self.sz = out_channel * (in_channel // groups) * kernel_size**2
        self.log_sigma2 = Parameter(
            t.FloatTensor(out_channel, in_channel // groups, kernel_size,
                          kernel_size).fill_(log_sigma2))
        # `scale`, as well as gaussian_window(), kllu() and _kl_loss() used
        # below, are module-level definitions from the original source file.
        self.s = Parameter(t.Tensor([scale]))
        self.code = t.Tensor([0.2, 0, -0.2])
        self.reset_parameters()
        self.k = [0.63576, 1.87320, 1.48695]
        self.threshold = threshold

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.out_channel)
        self.theta.data.uniform_(-stdv, stdv)
        # self.bias.data.uniform_(-stdv, stdv)

    @staticmethod
    def clip_logsig(input):
        input = input.masked_fill(input < -10, -10)
        input = input.masked_fill(input > 1, 1)
        return input

    def clip(self):
        # Clamp in place: the previous non-in-place masked_fill calls returned
        # new tensors that were silently discarded, so log_sigma2 was never
        # actually clipped.
        self.log_sigma2.data.clamp_(min=-10, max=1)
        self.theta.data = t.where(
            self.theta < (-0.2 - 0.3679 * t.sqrt(self.log_sigma2.exp())),
            (-0.2 - 0.3679 * t.sqrt(self.log_sigma2.exp())), self.theta)
        self.theta.data = t.where(
            self.theta > (0.2 + 0.3679 * t.sqrt(self.log_sigma2.exp())),
            (0.2 + 0.3679 * t.sqrt(self.log_sigma2.exp())), self.theta)

    def kld(self, idx):
        window1 = gaussian_window(self.theta * self.s, 0.2)
        window2 = gaussian_window(self.theta * self.s, -0.2)
        log_alpha1 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (self.theta * self.s - 0.2)**2)
        log_alpha2 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (self.theta * self.s)**2)
        log_alpha3 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (self.theta * self.s + 0.2)**2)
        F_KLLU1 = kllu(log_alpha1)
        F_KLLU2 = kllu(log_alpha2)
        F_KLLU3 = kllu(log_alpha3)
        F_KL = F_KLLU1 * window1 + F_KLLU3 * window2 \
            + F_KLLU2 * (1 - window1 - window2)
        return F_KL.sum() / self.sz

    def forward(self, input, train, noquan):
        """
        :param input: A float tensor with shape of
            [batch_size, in_channel, height, width]
        :return: The output feature map and the layer KLD estimate
        """
        self.clip()
        c1 = (self.theta * self.s - 0.2)**2
        c2 = (self.theta * self.s)**2
        c3 = (self.theta * self.s + 0.2)**2
        mean = t.min(t.min(c1, c2), c3)
        c = t.stack((c1, c2, c3), 0)
        idx = t.argmin(c, 0)
        if not train and not noquan:
            """
            mask = log_alpha > self.threshold
            return F.conv2d(
                input, weight=self.theta.masked_fill(mask, 0),
                stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups)
            """
            # Quantized inference: snap each weight to its nearest code value.
            theta_q = self.theta.data.clone()
            theta_q[:] = self.code[idx].cuda() / self.s
            mu = F.conv2d(input,
                          weight=theta_q,
                          stride=self.stride,
                          padding=self.padding,
                          dilation=self.dilation,
                          groups=self.groups)
            kld = t.sum((theta_q - self.theta)**2)
            return mu, kld  # +self.bias , kld
        if noquan:
            kld = 0
            theta_q = self.theta.data.clone()
            mu = F.conv2d(input,
                          weight=theta_q,
                          stride=self.stride,
                          padding=self.padding,
                          dilation=self.dilation,
                          groups=self.groups)
            return mu, kld  # +self.bias , kld
        kld = _kl_loss(self.theta, self.log_sigma2, self.prior_theta,
                       self.prior_log_sigma2) / self.sz
        mu = F.conv2d(input,
                      weight=self.theta * self.s,
                      stride=self.stride,
                      padding=self.padding,
                      dilation=self.dilation,
                      groups=self.groups)
        std = t.sqrt(
            F.conv2d(input**2,
                     weight=self.log_sigma2.exp() * self.s**2,
                     stride=self.stride,
                     padding=self.padding,
                     dilation=self.dilation,
                     groups=self.groups) + 1e-6)
        eps = Variable(t.randn(*mu.size()))
        if input.is_cuda:
            eps = eps.cuda()
        return std * eps + mu, kld  # + self.bias , kld

    def max_alpha(self):
        log_alpha = self.log_sigma2 - (self.theta - 0.2)**2
        return t.max(log_alpha.exp())
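# --- Illustrative usage sketch (not part of the original sources) ---
# VariationalDropoutCNN returns (activation, kld) and is switched with the
# explicit `train` / `noquan` flags rather than with module.train()/.eval().
# The helper below wires a small stack of such layers together and accumulates
# the per-layer KLD terms for the loss; `layers` and `x` are placeholders.
def quantized_cnn_forward(layers, x, train=True, noquan=False):
    total_kld = 0.
    for layer in layers:
        x, kld = layer(x, train, noquan)
        x = F.relu(x)
        total_kld = total_kld + kld
    return x, total_kld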
class MaskedVDropConv2d(nn.Module): """ A self-contained masked Conv2d (doesn't use the VDropCentralData) """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, mask=None, w_logvar_init=-10): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.groups = groups self.w_logvar_min = min(w_logvar_init, -10) self.w_logvar_max = 10. self.pruned_logvar_sentinel = self.w_logvar_max - 0.00058 self.epsilon = 1e-8 self.w_mu = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.w_logvar = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.bias = None self.w_logvar.data.fill_(w_logvar_init) self.register_buffer( "w_mask", torch.HalfTensor(out_channels, in_channels // groups, *self.kernel_size)) # Standard nn.Conv2d initialization. init.kaiming_uniform_(self.w_mu, a=math.sqrt(5)) if mask is not None: self.w_mask[:] = mask self.w_mu.data *= self.w_mask self.w_logvar.data[self.w_mask == 0.0] = self.pruned_logvar_sentinel else: self.w_mask.fill_(1.0) # Standard nn.Conv2d initialization. if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.w_mu) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) self.tensor_constructor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) def extra_repr(self): s = (f"{self.in_channels}, {self.out_channels}, " f"kernel_size={self.kernel_size}, stride={self.stride}") if self.padding != (0, ) * len(self.padding): s += f", padding={self.padding}" if self.dilation != (1, ) * len(self.dilation): s += f", dilation={self.dilation}" if self.groups != 1: s += f", groups={self.groups}" if self.bias is None: s += ", bias=False" return s def get_w_mu(self): return self.w_mu * self.w_mask def get_w_var(self): return self.w_logvar.exp() * self.w_mask def forward(self, x): if self.training: return vdrop_conv_forward(x, self.get_w_mu, self.get_w_var, self.bias, self.stride, self.padding, self.dilation, self.groups, self.tensor_constructor) else: return F.conv2d(x, self.get_w_mu(), self.bias, self.stride, self.padding, self.dilation, self.groups) def compute_w_logalpha(self): return self.w_logvar - (self.w_mu.square() + self.epsilon).log() def regularization(self): return (vdrop_regularization(self.compute_w_logalpha()) * self.w_mask).sum() def constrain_parameters(self): self.w_logvar.data.clamp_(min=self.w_logvar_min, max=self.w_logvar_max)
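# --- Illustrative usage sketch (not part of the original sources) ---
# MaskedVDropConv2d can be built around an existing binary mask, e.g. one taken
# from a previously pruned model: masked-out positions have their means zeroed
# and their log-variances pinned to pruned_logvar_sentinel, and they are
# excluded from regularization() via the w_mask factor. Shapes below are
# placeholders.
example_mask = (torch.rand(64, 32, 3, 3) > 0.5).float()
masked_conv = MaskedVDropConv2d(32, 64, kernel_size=3, padding=1,
                                mask=example_mask)
assert masked_conv.get_w_mu()[masked_conv.w_mask == 0].abs().sum() == 0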
class VDropConv2d(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.groups = groups self.w_mu = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) init.kaiming_normal_(self.w_mu, mode="fan_out") self.w_logsigma2 = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.w_logsigma2.data.fill_(-10) if bias: self.bias = Parameter(torch.Tensor(out_channels)) self.bias.data.fill_(0) else: self.bias = None self.input_shape = None self.threshold = 3 self.epsilon = 1e-8 self.tensor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) def compute_mask(self): w_logalpha = self.w_logsigma2 - (self.w_mu ** 2 + self.epsilon).log() return (w_logalpha < self.threshold).float() def forward(self, x): if self.input_shape is None: self.input_shape = x.size() if self.training: y_mu = F.conv2d(x, self.w_mu, self.bias, self.stride, self.padding, self.dilation, self.groups) # Avoid sqrt(0), otherwise a divide-by-zero occurs during backprop. y_sigma = F.conv2d( x ** 2, self.w_logsigma2.exp(), None, self.stride, self.padding, self.dilation, self.groups ).clamp(self.epsilon).sqrt() rv = self.tensor(y_mu.size()).normal_() return y_mu + (rv * y_sigma) else: return F.conv2d(x, self.w_mu * self.compute_mask(), self.bias, self.stride, self.padding, self.dilation, self.groups) def regularization(self): k1, k2, k3 = 0.63576, 1.8732, 1.48695 w_logalpha = self.w_logsigma2 - (self.w_mu ** 2 + self.epsilon).log() return -(k1 * torch.sigmoid(k2 + k3 * w_logalpha) - 0.5 * F.softplus(-w_logalpha) - k1).sum() def get_inference_nonzeros(self): mask = self.compute_mask().int() return mask.sum(dim=tuple(range(1, len(mask.shape)))) def count_inference_flops(self): # For each unit, multiply with its n inputs then do n - 1 additions. # To capture the -1, subtract it, but only in cases where there is at # least one weight. nz_by_unit = self.get_inference_nonzeros() multiplies_per_instance = torch.sum(nz_by_unit) adds_per_instance = multiplies_per_instance - torch.sum(nz_by_unit > 0) # for rows instances = ( (self.input_shape[-2] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 # multiplying with cols instances *= ( (self.input_shape[-1] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 multiplies = multiplies_per_instance * instances adds = adds_per_instance * instances return multiplies.item(), adds.item() def weight_size(self): return self.w_mu.size()
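# --- Illustrative usage sketch (not part of the original sources) ---
# count_inference_flops() relies on self.input_shape, which is only recorded on
# the first forward pass, so the layer needs to see data before the estimate is
# meaningful. Shapes below are placeholders.
vdrop_conv = VDropConv2d(3, 16, kernel_size=3, padding=1)
vdrop_conv.eval()
_ = vdrop_conv(torch.randn(1, 3, 32, 32))
flop_multiplies, flop_adds = vdrop_conv.count_inference_flops()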
class discrete_vision_actor_critic_Net(nn.Module): def __init__(self, s_dim, n_actions, latent_dim, n_heads=8, init_log_alpha=0.0, parallel=True, lr=1e-4, lr_alpha=1e-4, lr_actor=1e-4): super().__init__() self.s_dim = s_dim self.n_actions = n_actions self._parallel = parallel self.q = vision_multihead_dueling_q_Net(s_dim, latent_dim, n_actions, n_heads, lr) self.q_target = vision_multihead_dueling_q_Net(s_dim, latent_dim, n_actions, n_heads, lr) self.update(rate=1.0) self.actor = vision_softmax_policy_Net(s_dim, latent_dim, n_actions, noisy=False, lr=lr_alpha) self.log_alpha = Parameter(torch.Tensor(1)) nn.init.constant_(self.log_alpha, init_log_alpha) self.alpha_optimizer = Adam([self.log_alpha], lr=lr_alpha) def forward(self): pass def evaluate_critic(self, inner_state, outer_state, next_inner_state, next_outer_state): q = self.q(inner_state, outer_state) next_q = self.q_target(next_inner_state, next_outer_state) next_pi, next_log_pi = self.actor(next_inner_state, next_outer_state) log_alpha = self.log_alpha.view(-1, 1) return q, next_q, next_pi, next_log_pi, log_alpha def evaluate_actor(self, inner_state, outer_state): q = self.q(inner_state, outer_state) pi, log_pi = self.actor(inner_state, outer_state) return q, pi, log_pi def sample_action(self, inner_state, outer_state, explore=True): PA_s = self.actor(inner_state.view(1, -1), outer_state.unsqueeze(0))[0].squeeze(0).view(-1) assert torch.all(PA_s == PA_s), 'Boom. Capoot.' if explore: A = Categorical(probs=PA_s).sample().item() else: tie_breaking_dist = torch.isclose(PA_s, PA_s.max()).float() tie_breaking_dist /= tie_breaking_dist.sum() A = Categorical(probs=tie_breaking_dist).sample().item() return A, PA_s.detach().cpu().numpy() def update(self, rate=5e-3): updateNet(self.q_target, self.q, rate) def get_alpha(self): return self.log_alpha.exp().item()
class GNJConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True): # Init torch module super(GNJConv2d, self).__init__() # Init conv params self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) self.stride = stride self.padding = padding self.dilation = dilation # Init filter latents self.weight_mu = Parameter(Tensor(out_channels, in_channels, *self.kernel_size)) self.weight_logvar = Parameter(Tensor(out_channels, in_channels, *self.kernel_size)) self.bias = bias self.bias_mu = Parameter(Tensor(out_channels)) if self.bias else None self.bias_logvar = Parameter(Tensor(out_channels)) if self.bias else None # Init prior latents self.z_mu = Parameter(Tensor(out_channels)) self.z_logvar = Parameter(Tensor(out_channels)) # Set initial parameters self._init_params() # for brevity to conv2d calls self.convargs = [self.stride, self.padding, self.dilation] # util activations self.sigmoid = Sigmoid() self.softplus = Softplus() # forward network pass def forward(self, x): # vanilla forward pass if testing if not self.training: post_weight_mu = self.weight_mu * self.z_mu[:, None, None, None] post_bias_mu = self.bias_mu * self.z_mu if (self.bias_mu is not None) else None return conv2d(x, post_weight_mu, post_bias_mu, *self.convargs) #batch_size = x.size()[0] # unpack mean/std mu = self.z_mu std = torch.exp(0.5 * self.z_logvar) # rsample: sample scale prior with reparam trick z = Normal(mu, std).rsample()[None, :, None, None] # weights and biases for variance estimation weight_v = self.weight_logvar.exp() bias_v = self.bias_logvar.exp() if self.bias else None # parameterise output distribution mu_out = conv2d(x, self.weight_mu, self.bias_mu, *self.convargs) * z var_out = conv2d(x**2, weight_v, bias_v, *self.convargs) * (z ** 2) # Init out, note multiplicative noise==variational dropout dist_out = Normal(mu_out, var_out.sqrt()).rsample() #dist_out = self.reparam(mu_out*z, (var_out * z.pow(2)).log()) return dist_out def _init_params(self, weight=None, bias=None): n = self.in_channels * self.kernel_size[0] * self.kernel_size[1] thresh = 1/math.sqrt(n) # weights self.weight_logvar.data.normal_(-9, 1e-2) if weight is not None: self.weight_mu.data = weight else: self.weight_mu.data.uniform_(-thresh, thresh) if self.bias: # biases self.bias_logvar.data.normal_(-9, 1e-2) if bias is not None: self.bias_mu.data = bias else: self.bias_mu.data.fill_(0) # priors self.z_mu.data.normal_(1, 1e-2) self.z_logvar.data.normal_(-9, 1e-2) # shape,scale family reparameterization trick (rsample does this?) def reparam(self, mu, logvar): std = torch.exp(0.5 * logvar) # check for cuda #tenv = torch.cuda if cuda else torch # draw from normal eps = torch.FloatTensor(std.size()).normal_() return mu + eps * std # KL div for GNJ w. Normal approx posterior def kl_divergence(self): # for brevity in kl_scale sg = self.sigmoid sp = self.softplus # Approximation parameters. Molchanov et al. 
k1, k2, k3 = 0.63576, 1.87320, 1.48695 log_alpha = self._log_alpha() kl_scale = torch.sum(0.5 * sp(-log_alpha) + k1 - k1 * sg(k2 + k3 * log_alpha)) kl_weight = self._conditional_kl_div(self.weight_mu, self.weight_logvar) kl_bias = self._conditional_kl_div(self.bias_mu, self.bias_logvar) if self.bias else 0 return kl_scale + kl_weight + kl_bias @staticmethod def _conditional_kl_div(mu, logvar): # (8) Weight/bias divergence KL(q(w|z)||p(w|z)) kl_div = -0.5 * logvar + 0.5 * (logvar.exp() + mu ** 2 - 1) return torch.sum(kl_div) # effective dropout rate def _log_alpha(self): epsilon = 1e-8 log_a = self.z_logvar - torch.log(self.z_mu ** 2 + epsilon) return log_a
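# --- Illustrative usage sketch (not part of the original sources) ---
# With Group Normal-Jeffreys convolutions the training objective is the usual
# negative ELBO: the data term plus the summed kl_divergence() of every
# GNJConv2d layer, here scaled by 1/N for a dataset of N examples. `model`,
# `x`, `y` and `num_train` are placeholders, assuming a classification setup.
def gnj_loss(model, x, y, num_train):
    kl = sum(m.kl_divergence() for m in model.modules()
             if isinstance(m, GNJConv2d))
    return F.cross_entropy(model(x), y) + kl / num_train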
class LinearGroupNJ(Module): """Fully Connected Group Normal-Jeffrey's layer (aka Group Variational Dropout). References: [1] Kingma, Diederik P., Tim Salimans, and Max Welling. "Variational dropout and the local reparameterization trick." NIPS (2015). [2] Molchanov, Dmitry, Arsenii Ashukha, and Dmitry Vetrov. "Variational Dropout Sparsifies Deep Neural Networks." ICML (2017). [3] Louizos, Christos, Karen Ullrich, and Max Welling. "Bayesian Compression for Deep Learning." NIPS (2017). """ def __init__(self, in_features, out_features, cuda=False, init_weight=None, init_bias=None, clip_var=None): super(LinearGroupNJ, self).__init__() self.cuda = cuda self.in_features = in_features self.out_features = out_features self.clip_var = clip_var self.deterministic = False # flag is used for compressed inference # trainable params according to Eq.(6) # dropout params self.z_mu = Parameter(torch.Tensor(in_features)) self.z_logvar = Parameter(torch.Tensor(in_features)) # = z_mu^2 * alpha # weight params self.weight_mu = Parameter(torch.Tensor(out_features, in_features)) self.weight_logvar = Parameter(torch.Tensor(out_features, in_features)) self.bias_mu = Parameter(torch.Tensor(out_features)) self.bias_logvar = Parameter(torch.Tensor(out_features)) # init params either random or with pretrained net self.reset_parameters(init_weight, init_bias) # activations for kl self.sigmoid = nn.Sigmoid() self.softplus = nn.Softplus() # numerical stability param self.epsilon = 1e-8 def reset_parameters(self, init_weight, init_bias): # init means stdv = 1. / math.sqrt(self.weight_mu.size(1)) self.z_mu.data.normal_(1, 1e-2) if init_weight is not None: self.weight_mu.data = torch.Tensor(init_weight) else: self.weight_mu.data.normal_(0, stdv) if init_bias is not None: self.bias_mu.data = torch.Tensor(init_bias) else: self.bias_mu.data.fill_(0) # init logvars self.z_logvar.data.normal_(-9, 1e-2) self.weight_logvar.data.normal_(-9, 1e-2) self.bias_logvar.data.normal_(-9, 1e-2) def clip_variances(self): if self.clip_var: self.weight_logvar.data.clamp_(max=math.log(self.clip_var)) self.bias_logvar.data.clamp_(max=math.log(self.clip_var)) def get_log_dropout_rates(self): log_alpha = self.z_logvar - torch.log(self.z_mu.pow(2) + self.epsilon) return log_alpha def compute_posterior_params(self): weight_var, z_var = self.weight_logvar.exp(), self.z_logvar.exp() self.post_weight_var = self.z_mu.pow(2) * weight_var + z_var * self.weight_mu.pow(2) + z_var * weight_var self.post_weight_mu = self.weight_mu * self.z_mu return self.post_weight_mu, self.post_weight_var def forward(self, x): if self.deterministic: assert self.training == False, "Flag deterministic is True. This should not be used in training." return F.linear(x, self.post_weight_mu, self.bias_mu) batch_size = x.size()[0] # compute z # note that we reparametrise according to [2] Eq. (11) (not [1]) z = reparametrize(self.z_mu.repeat(batch_size, 1), self.z_logvar.repeat(batch_size, 1), sampling=self.training) # apply local reparametrisation trick see [1] Eq. (6) # to the parametrisation given in [3] Eq. 
(6) xz = x * z mu_activations = F.linear(xz, self.weight_mu, self.bias_mu) var_activations = F.linear(xz.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp()) return reparametrize(mu_activations, var_activations.log(), sampling=self.training) def kl_divergence(self): # KL(q(z)||p(z)) # we use the kl divergence approximation given by [2] Eq.(14) k1, k2, k3 = 0.63576, 1.87320, 1.48695 log_alpha = self.get_log_dropout_rates() KLD = -torch.sum(k1 * self.sigmoid(k2 + k3 * log_alpha) - 0.5 * self.softplus(-log_alpha) - k1) # KL(q(w|z)||p(w|z)) # we use the kl divergence given by [3] Eq.(8) KLD_element = -0.5 * self.weight_logvar + 0.5 * (self.weight_logvar.exp() + self.weight_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) # KL bias KLD_element = -0.5 * self.bias_logvar + 0.5 * (self.bias_logvar.exp() + self.bias_mu.pow(2)) - 0.5 KLD += torch.sum(KLD_element) return KLD def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')'
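# --- Illustrative usage sketch (not part of the original sources) ---
# Compressed inference with LinearGroupNJ: collapse the posterior once, switch
# on the deterministic flag, and read off which input groups survive from the
# log dropout rates. The 3.0 threshold mirrors the one used elsewhere in this
# collection; `model` is a placeholder.
def prepare_compressed_inference(model, logalpha_threshold=3.0):
    model.eval()
    keep_masks = []
    for m in model.modules():
        if isinstance(m, LinearGroupNJ):
            m.compute_posterior_params()
            m.deterministic = True
            keep_masks.append(m.get_log_dropout_rates() < logalpha_threshold)
    return keep_masks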
class FFGaussConv2d(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, prior_std=1, **kwargs): super(FFGaussConv2d, self).__init__() if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: raise ValueError('out_channels must be divisible by groups') self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.prior_std = prior_std self.use_bias = False self.mean_w = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.logvar_w = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) if bias: self.mean_bias = Parameter(torch.Tensor(out_channels)) self.logvar_bias = Parameter(torch.Tensor(out_channels)) self.use_bias = True self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.mean_w, mode='fan_in') self.logvar_w.data.normal_(-9., 1e-4) if self.use_bias: self.mean_bias.data.zero_() self.logvar_bias.data.normal_(-9., 1e-4) def constrain_parameters(self, thres_std=1.): self.logvar_w.data.clamp_(max=2. * math.log(thres_std)) if self.use_bias: self.logvar_bias.data.clamp_(max=2. * math.log(thres_std)) def eq_logpw(self): logpw = -.5 * math.log( 2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div( self.prior_std**2) logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2) logpb = 0. if self.use_bias: logpb = -.5 * math.log(2 * math.pi * self.prior_std** 2) - .5 * self.logvar_bias.exp().div( self.prior_std**2) logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2) return torch.sum(logpw) + torch.sum(logpb) def eq_logqw(self): logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1)) logqb = 0. if self.use_bias: logqb = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_bias + 1)) return logqw + logqb def kldiv_aux(self): return 0. def kldiv(self): return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw() def get_eps(self, size): eps = self.floatTensor(size).normal_() eps = Variable(eps) return eps def sample_W(self): W = self.mean_w if self.training: eps = self.get_eps(self.mean_w.size()) W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_())) return W def sample_b(self): if not self.use_bias: return None b = self.mean_bias if self.training: eps = self.get_eps(self.mean_bias.size()) b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_())) return b def forward(self, input_): W = self.sample_W() b = self.sample_b() return F.conv2d(input_, W, b, self.stride, self.padding, self.dilation, self.groups) def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}, prior_std={prior_std}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if not self.use_bias: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
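# --- Illustrative note (not part of the original sources) ---
# Unlike FFGaussDense, which applies the local reparameterization trick to the
# activations, FFGaussConv2d samples the weights themselves via sample_W() /
# sample_b(); in eval mode both reduce to the posterior means, so inference is
# deterministic. Shapes below are placeholders.
ffgauss_conv = FFGaussConv2d(3, 8, kernel_size=3, padding=1)
ffgauss_conv.eval()
x_demo = torch.randn(2, 3, 16, 16)
assert torch.equal(ffgauss_conv(x_demo), ffgauss_conv(x_demo))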
class VariationalDropout(nn.Module):
    def __init__(self, input_size, out_size, log_sigma2=-8, threshold=3):
        """
        :param input_size: An int of input size
        :param out_size: An int of output size
        :param log_sigma2: Initial value of log sigma ^ 2. It is crucial for
            training since it determines the initial value of alpha
        :param threshold: Value for thresholding at validation. If
            log_alpha > threshold, then the weight is zeroed
        """
        super(VariationalDropout, self).__init__()

        self.input_size = input_size
        self.out_size = out_size

        self.theta = Parameter(t.FloatTensor(input_size, out_size))
        self.bias = Parameter(t.Tensor(out_size))
        self.prior_theta = 0.
        self.prior_log_sigma2 = -2.
        self.log_sigma2 = Parameter(
            t.FloatTensor(input_size, out_size).fill_(log_sigma2))
        self.sz = input_size * out_size
        # `scale`, as well as gaussian_window(), kllu() and _kl_loss() used
        # below, are module-level definitions from the original source file.
        self.s = Parameter(t.Tensor([scale]))
        self.code = t.Tensor([0.2, 0, -0.2])
        self.reset_parameters()
        self.k = [0.63576, 1.87320, 1.48695]
        self.threshold = threshold

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.out_size)
        self.theta.data.uniform_(-stdv, stdv)
        self.bias.data.uniform_(-stdv, stdv)

    def clip(self):
        # Ordinary instance method (the original @staticmethod decorator was
        # spurious). Clamp in place: the previous non-in-place masked_fill
        # calls returned new tensors that were silently discarded, so
        # log_sigma2 was never actually clipped.
        self.log_sigma2.data.clamp_(min=-10, max=1)
        self.theta.data = t.where(
            self.theta < (-0.2 - 0.3679 * t.sqrt(self.log_sigma2.exp())),
            (-0.2 - 0.3679 * t.sqrt(self.log_sigma2.exp())), self.theta)
        self.theta.data = t.where(
            self.theta > (0.2 + 0.3679 * t.sqrt(self.log_sigma2.exp())),
            (0.2 + 0.3679 * t.sqrt(self.log_sigma2.exp())), self.theta)

    @staticmethod
    def clip__(input, to=8):
        # Unused helper kept from the original source.
        input = input.masked_fill(input < -to, -to)
        input = input.masked_fill(input > to, to)
        return input

    # def kllu(self, log_alpha):
    #     first_term = self.k[0] * F.sigmoid(self.k[1] + self.k[2] * log_alpha)
    #     second_term = 0.5 * t.log(1 + t.exp(-log_alpha))
    #     return -(first_term - second_term - self.k[0])

    def kld(self, mean, idx):
        window1 = gaussian_window(mean * self.s, 0.2)
        window2 = gaussian_window(mean * self.s, -0.2)
        log_alpha1 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (mean * self.s - 0.2)**2)
        log_alpha2 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (mean * self.s)**2)
        log_alpha3 = self.log_sigma2 + 2 * t.log(self.s) - t.log(
            (mean * self.s + 0.2)**2)
        F_KLLU1 = kllu(log_alpha1)
        F_KLLU2 = kllu(log_alpha2)
        F_KLLU3 = kllu(log_alpha3)
        F_KL = F_KLLU1 * window1 + F_KLLU3 * window2 \
            + F_KLLU2 * (1 - window1 - window2)
        return F_KL.sum() / self.sz

    def forward(self, input, train, noquan):
        """
        :param input: A float tensor with shape of [batch_size, input_size]
        :return: A float tensor with shape of [batch_size, out_size] and the
            layer KLD estimate
        """
        self.clip()
        c1 = (self.theta * self.s - 0.2)**2
        c2 = (self.theta * self.s)**2
        c3 = (self.theta * self.s + 0.2)**2
        mean = t.min(t.min(c1, c2), c3)
        c = t.stack((c1, c2, c3), 0)
        idx = t.argmin(c, 0)
        if not train and not noquan:
            """
            mask = log_alpha > self.threshold
            return t.addmm(self.bias, input, self.theta.masked_fill(mask, 0))
            """
            # Quantized inference: snap each weight to its nearest code value.
            theta_q = self.theta.data.clone()
            theta_q[:] = self.code[idx].cuda() / self.s
            mu = t.mm(input, theta_q)
            kld = t.sum((theta_q - self.theta)**2)
            return mu + self.bias, kld
        if noquan:
            kld = 0
            theta_q = self.theta.data.clone()
            mu = t.mm(input, theta_q)
            return mu + self.bias, kld
        kld = _kl_loss(self.theta, self.log_sigma2, self.prior_theta,
                       self.prior_log_sigma2) / self.sz
        mu = t.mm(input, self.theta * self.s)
        std = t.sqrt(t.mm(input**2, self.s**2 * self.log_sigma2.exp()) + 1e-6)
        eps = Variable(t.randn(*mu.size()))
        if input.is_cuda:
            eps = eps.cuda()
        return std * eps + mu + self.bias, kld

    def max_alpha(self):
        log_alpha = self.log_sigma2 - self.theta**2
        return t.max(log_alpha.exp())
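# --- Illustrative usage sketch (not part of the original sources) ---
# The quantization-aware VariationalDropout layer above returns (output, kld)
# and is driven by explicit `train` / `noquan` flags: noquan=True gives a plain
# deterministic pass, train=True adds the stochastic pass with the KL penalty,
# and train=False snaps the weights to the learned 3-level code {-0.2, 0, 0.2}
# divided by the scale s. Note that the quantized branch calls .cuda()
# unconditionally, so it assumes GPU execution. `layer`, `x`, `target` and
# `beta` are placeholders.
out, kld = layer(x, train=True, noquan=False)           # stochastic, KL-regularized
loss = F.mse_loss(out, target) + beta * kld
out_q, quant_err = layer(x, train=False, noquan=False)  # quantized inference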
class VDropConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.groups = groups self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) init.kaiming_normal_(self.weight, mode="fan_out") self.w_logvar = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.w_logvar.data.fill_(-10) if bias: self.bias = Parameter(torch.Tensor(out_channels)) self.bias.data.fill_(0) else: self.bias = None self.input_shape = None self.threshold = 3 self.epsilon = 1e-8 self.tensor_constructor = (torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor) def extra_repr(self): s = (f"{self.in_channels}, {self.out_channels}, " f"kernel_size={self.kernel_size}, stride={self.stride}") if self.padding != (0,) * len(self.padding): s += f", padding={self.padding}" if self.dilation != (1,) * len(self.dilation): s += f", dilation={self.dilation}" if self.groups != 1: s += f", groups={self.groups}" if self.bias is None: s += ", bias=False" return s def constrain_parameters(self): self.w_logvar.data.clamp_(min=-10., max=10.) def compute_mask(self): w_logalpha = self.w_logvar - (self.weight ** 2 + self.epsilon).log() return (w_logalpha < self.threshold).float() def forward(self, x): if self.input_shape is None: self.input_shape = x.size() if self.training: return vdrop_conv_forward(x, lambda: self.weight, lambda: self.w_logvar.exp(), self.bias, self.stride, self.padding, self.dilation, self.groups, self.tensor_constructor, self.epsilon) else: return F.conv2d(x, self.weight * self.compute_mask(), self.bias, self.stride, self.padding, self.dilation, self.groups) def regularization(self): w_logalpha = self.w_logvar - (self.weight ** 2 + self.epsilon).log() return vdrop_regularization(w_logalpha).sum() def get_inference_nonzeros(self): mask = self.compute_mask().int() return mask.sum(dim=tuple(range(1, len(mask.shape)))) def count_inference_flops(self): # For each unit, multiply with its n inputs then do n - 1 additions. # To capture the -1, subtract it, but only in cases where there is at # least one weight. nz_by_unit = self.get_inference_nonzeros() multiplies_per_instance = torch.sum(nz_by_unit) adds_per_instance = multiplies_per_instance - torch.sum(nz_by_unit > 0) # for rows instances = ( (self.input_shape[-2] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1 # multiplying with cols instances *= ( (self.input_shape[-1] - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1]) + 1 multiplies = multiplies_per_instance * instances adds = adds_per_instance * instances return multiplies.item(), adds.item() def weight_size(self): return self.weight.size()
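# --- Illustrative usage sketch (not part of the original sources) ---
# The VDrop layers expose constrain_parameters() to keep w_logvar within
# [-10, 10]; calling it after each optimizer step applies the clamp to the
# freshly updated values. `model`, `optimizer` and `loss` are placeholders.
def vdrop_optimizer_step(model, optimizer, loss):
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    for m in model.modules():
        if isinstance(m, (VDropLinear2, MaskedVDropConv2d, VDropConv2d)):
            m.constrain_parameters()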