def test_scale_mixture_prior(self):
    mu = torch.Tensor(10, 10).uniform_(-1, 1)
    rho = torch.Tensor(10, 10).uniform_(-1, 1)

    dist = TrainableRandomDistribution(mu, rho)
    s1 = dist.sample()
    log_posterior = dist.log_posterior()

    prior_dist = PriorWeightDistribution(0.5, 1, .002)
    log_prior = prior_dist.log_prior(s1)

    # log_prior == log_prior fails only when log_prior is NaN
    self.assertEqual(log_prior == log_prior, torch.tensor(True))
    # equivalent to asserting that log_prior <= 0
    self.assertEqual(log_posterior <= log_posterior - log_prior, torch.tensor(True))
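# For reference, a minimal standalone sketch of the two-component scale mixture
# density the test above exercises. This is an illustrative reimplementation under
# the usual Bayes by Backprop assumptions, not the library's PriorWeightDistribution;
# the name `mixture_log_prior` is hypothetical.
import torch
from torch.distributions import Normal

def mixture_log_prior(w, pi=0.5, sigma_1=1.0, sigma_2=0.002):
    # log p(w) for p(w) = pi * N(0, sigma_1^2) + (1 - pi) * N(0, sigma_2^2),
    # summed over all weight entries (the prior factorizes over weights)
    prob_1 = torch.exp(Normal(0., sigma_1).log_prob(w))
    prob_2 = torch.exp(Normal(0., sigma_2).log_prob(w))
    return torch.log(pi * prob_1 + (1 - pi) * prob_2 + 1e-6).sum()

# For weights spread away from zero this sum is typically negative, which is
# what the test's second assertion (log_prior <= 0) relies on.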
class BayesianGRU(BayesianRNN):
    """
    Bayesian GRU layer, implements a GRU layer with weight uncertainty as proposed in
    Weight Uncertainty in Neural Networks (Bayes by Backprop paper).

    It is designed to be interoperable with the torch nn.Module API, and can be chained
    in nn.Sequential models together with layers from other libraries.

    parameters:
        in_features: int -> incoming features for the layer
        out_features: int -> output features for the layer
        bias: bool -> whether the bias will exist (True) or be set to zero (False)
        prior_sigma_1: float -> sigma of the first component of the scale mixture prior
        prior_sigma_2: float -> sigma of the second component of the scale mixture prior
        prior_pi: float -> pi on the scaled mixture prior
        posterior_mu_init: float -> init mean for the variational posterior mu of the weights
        posterior_rho_init: float -> init mean for the variational posterior rho of the weights
        freeze: bool -> whether the model starts with frozen (deterministic) weights or not
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_sigma_1=0.1,
                 prior_sigma_2=0.002,
                 prior_pi=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-6.0,
                 freeze=False,
                 prior_dist=None,
                 **kwargs):
        super().__init__(**kwargs)

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias
        self.freeze = freeze

        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # Variational weight parameters and sample for weight ih
        self.weight_ih_mu = nn.Parameter(
            torch.Tensor(in_features, out_features * 4).normal_(posterior_mu_init, 0.1))
        self.weight_ih_rho = nn.Parameter(
            torch.Tensor(in_features, out_features * 4).normal_(posterior_rho_init, 0.1))
        self.weight_ih_sampler = TrainableRandomDistribution(
            self.weight_ih_mu, self.weight_ih_rho)
        self.weight_ih = None

        # Variational weight parameters and sample for weight hh
        self.weight_hh_mu = nn.Parameter(
            torch.Tensor(out_features, out_features * 4).normal_(posterior_mu_init, 0.1))
        self.weight_hh_rho = nn.Parameter(
            torch.Tensor(out_features, out_features * 4).normal_(posterior_rho_init, 0.1))
        self.weight_hh_sampler = TrainableRandomDistribution(
            self.weight_hh_mu, self.weight_hh_rho)
        self.weight_hh = None

        # Variational weight parameters and sample for bias
        self.bias_mu = nn.Parameter(
            torch.Tensor(out_features * 4).normal_(posterior_mu_init, 0.1))
        self.bias_rho = nn.Parameter(
            torch.Tensor(out_features * 4).normal_(posterior_rho_init, 0.1))
        self.bias_sampler = TrainableRandomDistribution(self.bias_mu, self.bias_rho)
        self.bias = None

        # our prior distributions
        self.weight_ih_prior_dist = PriorWeightDistribution(
            self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist=self.prior_dist)
        self.weight_hh_prior_dist = PriorWeightDistribution(
            self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist=self.prior_dist)
        self.bias_prior_dist = PriorWeightDistribution(
            self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist=self.prior_dist)

        self.init_sharpen_parameters()

        self.log_prior = 0
        self.log_variational_posterior = 0

    def sample_weights(self):
        # sample weights
        weight_ih = self.weight_ih_sampler.sample()
        weight_hh = self.weight_hh_sampler.sample()

        # if use_bias, we sample it; otherwise, we use zeros
        if self.use_bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)
        else:
            b = 0
            b_log_posterior = 0
            b_log_prior = 0

        bias = b

        # gather the variational posterior and prior log-likelihoods of the weights
        self.log_variational_posterior = self.weight_hh_sampler.log_posterior() \
            + b_log_posterior + self.weight_ih_sampler.log_posterior()

        self.log_prior = self.weight_ih_prior_dist.log_prior(weight_ih) \
            + b_log_prior + self.weight_hh_prior_dist.log_prior(weight_hh)

        self.ff_parameters = [weight_ih, weight_hh, bias]
        return weight_ih, weight_hh, bias

    def get_frozen_weights(self):
        # get all deterministic weights
        weight_ih = self.weight_ih_mu
        weight_hh = self.weight_hh_mu
        if self.use_bias:
            bias = self.bias_mu
        else:
            bias = 0

        return weight_ih, weight_hh, bias

    def forward_(self, x, hidden_states, sharpen_loss):
        if self.loss_to_sharpen is not None:
            sharpen_loss = self.loss_to_sharpen
            weight_ih, weight_hh, bias = self.sharpen_posterior(
                loss=sharpen_loss, input_shape=x.shape)
        elif sharpen_loss is not None:
            weight_ih, weight_hh, bias = self.sharpen_posterior(
                loss=sharpen_loss, input_shape=x.shape)
        else:
            weight_ih, weight_hh, bias = self.sample_weights()

        # Assumes x is of shape (batch, sequence, feature)
        bs, seq_sz, _ = x.size()

        # if no hidden state is given, we start from zeros
        if hidden_states is None:
            h_t = torch.zeros(bs, self.out_features).to(x.device)
        else:
            h_t = hidden_states

        # shorthand for the output feature size, and the hidden sequence list
        HS = self.out_features
        hidden_seq = []

        for t in range(seq_sz):
            x_t = x[:, t, :]
            # batch the gate computations into a single matrix multiplication
            A_t = x_t @ weight_ih[:, :HS * 2] + h_t @ weight_hh[:, :HS * 2] + bias[:HS * 2]

            r_t, z_t = (
                torch.sigmoid(A_t[:, :HS]),
                torch.sigmoid(A_t[:, HS:HS * 2]),
            )

            n_t = torch.tanh(
                x_t @ weight_ih[:, HS * 2:HS * 3] + bias[HS * 2:HS * 3]
                + r_t * (h_t @ weight_hh[:, HS * 3:HS * 4] + bias[HS * 3:HS * 4]))
            h_t = (1 - z_t) * n_t + z_t * h_t

            hidden_seq.append(h_t.unsqueeze(0))

        hidden_seq = torch.cat(hidden_seq, dim=0)
        # reshape from (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, h_t

    def forward_frozen(self, x, hidden_states):
        weight_ih, weight_hh, bias = self.get_frozen_weights()

        # Assumes x is of shape (batch, sequence, feature)
        bs, seq_sz, _ = x.size()

        # if no hidden state is given, we start from zeros
        if hidden_states is None:
            h_t = torch.zeros(bs, self.out_features).to(x.device)
        else:
            h_t = hidden_states

        # shorthand for the output feature size, and the hidden sequence list
        HS = self.out_features
        hidden_seq = []

        for t in range(seq_sz):
            x_t = x[:, t, :]
            # batch the gate computations into a single matrix multiplication
            A_t = x_t @ weight_ih[:, :HS * 2] + h_t @ weight_hh[:, :HS * 2] + bias[:HS * 2]

            r_t, z_t = (
                torch.sigmoid(A_t[:, :HS]),
                torch.sigmoid(A_t[:, HS:HS * 2]),
            )

            n_t = torch.tanh(
                x_t @ weight_ih[:, HS * 2:HS * 3] + bias[HS * 2:HS * 3]
                + r_t * (h_t @ weight_hh[:, HS * 3:HS * 4] + bias[HS * 3:HS * 4]))
            h_t = (1 - z_t) * n_t + z_t * h_t

            hidden_seq.append(h_t.unsqueeze(0))

        hidden_seq = torch.cat(hidden_seq, dim=0)
        # reshape from (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, h_t

    def forward(self, x, hidden_states=None, sharpen_loss=None):
        if self.freeze:
            return self.forward_frozen(x, hidden_states)

        if not self.sharpen:
            sharpen_loss = None

        return self.forward_(x, hidden_states, sharpen_loss)
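# A minimal usage sketch for the layer above, assuming BayesianGRU is importable
# from this package (the import path below is an assumption). Each forward pass
# samples fresh weights, so repeated calls on the same input give different outputs.
import torch
from blitz.modules import BayesianGRU  # assumed import path

gru = BayesianGRU(in_features=8, out_features=16)
x = torch.randn(4, 10, 8)          # (batch, sequence, feature)
hidden_seq, h_t = gru(x)           # hidden_seq: (4, 10, 16), h_t: (4, 16)

# the complexity cost gathered during this forward pass
kl_term = gru.log_variational_posterior - gru.log_prior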
class BayesianEmbedding(BayesianModule):
    """
    Bayesian Embedding layer, implements the embedding layer with weight uncertainty as
    proposed in Weight Uncertainty in Neural Networks (Bayes by Backprop paper).

    It is designed to be interoperable with the torch nn.Module API, and can be chained
    in nn.Sequential models together with layers from other libraries.

    parameters:
        num_embeddings: int -> size of the vocabulary
        embedding_dim: int -> dimension of the embedding
        prior_sigma_1: float -> sigma of the first component of the scale mixture prior
        prior_sigma_2: float -> sigma of the second component of the scale mixture prior
        prior_pi: float -> factor to scale the gaussian mixture of the model prior distribution
        freeze: bool -> whether the layer is instantiated as frozen (uses deterministic weights on the feedforward op)
        padding_idx: int -> if given, pads the output with the embedding vector at padding_idx (initialized to zeros) whenever it encounters the index
        max_norm: float -> if given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm
        norm_type: float -> the p of the p-norm to compute for the max_norm option. Default 2.
        scale_grad_by_freq: bool -> if given, scales gradients by the inverse of the frequency of the words in the mini-batch. Default False.
        sparse: bool -> if True, the gradient w.r.t. the weight matrix will be a sparse tensor
        posterior_mu_init: float -> init mean for the variational posterior mu of the weights
        posterior_rho_init: float -> init mean for the variational posterior rho of the weights
    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 prior_sigma_1=0.1,
                 prior_sigma_2=0.002,
                 prior_pi=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-6.0,
                 freeze=False,
                 prior_dist=None):
        super().__init__()

        self.freeze = freeze

        # parameters for the scale mixture prior
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        # Variational weight parameters and sample
        self.weight_mu = nn.Parameter(
            torch.Tensor(num_embeddings, embedding_dim).normal_(posterior_mu_init, 0.1))
        self.weight_rho = nn.Parameter(
            torch.Tensor(num_embeddings, embedding_dim).normal_(posterior_rho_init, 0.1))
        self.weight_sampler = TrainableRandomDistribution(self.weight_mu, self.weight_rho)

        # Priors (as in the BBP paper)
        self.weight_prior_dist = PriorWeightDistribution(
            self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist=self.prior_dist)

        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        # Sample the weights and forward with them
        # if the layer is frozen, use the deterministic forward pass
        if self.freeze:
            return self.forward_frozen(x)

        w = self.weight_sampler.sample()

        # Get the complexity cost
        self.log_variational_posterior = self.weight_sampler.log_posterior()
        self.log_prior = self.weight_prior_dist.log_prior(w)

        return F.embedding(x,
                           w,
                           self.padding_idx,
                           self.max_norm,
                           self.norm_type,
                           self.scale_grad_by_freq,
                           self.sparse)

    def forward_frozen(self, x):
        return F.embedding(x,
                           self.weight_mu,
                           self.padding_idx,
                           self.max_norm,
                           self.norm_type,
                           self.scale_grad_by_freq,
                           self.sparse)
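# A minimal usage sketch for the layer above, assuming BayesianEmbedding is
# importable from this package (the import path below is an assumption).
import torch
from blitz.modules import BayesianEmbedding  # assumed import path

emb = BayesianEmbedding(num_embeddings=1000, embedding_dim=32)
tokens = torch.randint(0, 1000, (4, 12))   # (batch, sequence) of token ids
out = emb(tokens)                          # (4, 12, 32); weights are sampled per call

# with freeze=True the forward pass uses weight_mu directly,
# giving a deterministic embedding lookup
frozen_emb = BayesianEmbedding(1000, 32, freeze=True)
det_out = frozen_emb(tokens)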
class BayesianConv1d(BayesianModule):

    # Implements a Bayesian Conv1d layer, drawing its weights with the Weight Uncertainty in Neural Networks algorithm
    """
    Bayesian Conv1d layer, implements a 1D convolution layer as proposed in Weight Uncertainty in Neural Networks (Bayes by Backprop paper).

    It is designed to be interoperable with the torch nn.Module API, and can be chained in nn.Sequential models together with layers from other libraries.

    parameters:
        in_channels: int -> incoming channels for the layer
        out_channels: int -> output channels for the layer
        kernel_size: int -> size of the kernels for this convolution layer
        groups: int -> number of groups on which the convolutions will happen
        stride: int -> stride of the convolution
        padding: int -> size of padding (0 if no padding)
        dilation: int -> dilation of the weights applied on the input tensor
        bias: bool -> whether the bias will exist (True) or be set to zero (False)
        prior_sigma_1: float -> sigma of the first component of the scale mixture prior
        prior_sigma_2: float -> sigma of the second component of the scale mixture prior
        prior_pi: float -> pi on the scaled mixture prior
        posterior_mu_init: float -> init mean for the variational posterior mu of the weights
        posterior_rho_init: float -> init mean for the variational posterior rho of the weights
        freeze: bool -> whether the model starts with frozen (deterministic) weights or not
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 groups=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 prior_sigma_1=0.1,
                 prior_sigma_2=0.002,
                 prior_pi=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-7.0,
                 freeze=False,
                 prior_dist=None):
        super().__init__()

        # our main parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.freeze = freeze
        self.kernel_size = kernel_size
        self.groups = groups
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.bias = bias

        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        # parameters for the scale mixture prior
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # our weights
        self.weight_mu = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, kernel_size).normal_(posterior_mu_init, 0.1))
        self.weight_rho = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, kernel_size).normal_(posterior_rho_init, 0.1))
        self.weight_sampler = TrainableRandomDistribution(self.weight_mu, self.weight_rho)

        # our biases
        self.bias_mu = nn.Parameter(torch.Tensor(out_channels).normal_(posterior_mu_init, 0.1))
        self.bias_rho = nn.Parameter(torch.Tensor(out_channels).normal_(posterior_rho_init, 0.1))
        self.bias_sampler = TrainableRandomDistribution(self.bias_mu, self.bias_rho)

        # Priors (as in the BBP paper)
        self.weight_prior_dist = PriorWeightDistribution(
            self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist=self.prior_dist)
        self.bias_prior_dist = PriorWeightDistribution(
            self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist=self.prior_dist)

        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        # Forward with uncertain weights; fills the bias with zeros if the layer has no bias
        # Also calculates the complexity cost for this sampling
        if self.freeze:
            return self.forward_frozen(x)

        w = self.weight_sampler.sample()

        if self.bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)
        else:
            b = torch.zeros((self.out_channels))
            b_log_posterior = 0
            b_log_prior = 0

        self.log_variational_posterior = self.weight_sampler.log_posterior() + b_log_posterior
        self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior

        return F.conv1d(input=x,
                        weight=w,
                        bias=b,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)

    def forward_frozen(self, x):
        # Computes the feedforward operation with the expected value for weights and biases (frozen-like)
        if self.bias:
            bias = self.bias_mu
            assert bias is self.bias_mu, "The bias passed in should be this layer's parameter, not a clone."
        else:
            bias = torch.zeros(self.out_channels)

        return F.conv1d(input=x,
                        weight=self.weight_mu,
                        bias=bias,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)
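# A minimal usage sketch for the layer above, assuming BayesianConv1d is
# importable from this package (the import path below is an assumption).
import torch
from blitz.modules import BayesianConv1d  # assumed import path

conv = BayesianConv1d(in_channels=3, out_channels=8, kernel_size=5, padding=2)
x = torch.randn(4, 3, 100)     # (batch, channels, length)
out = conv(x)                  # (4, 8, 100); weights and bias are resampled per call

# the complexity cost gathered during this forward pass
kl_term = conv.log_variational_posterior - conv.log_prior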
class BayesianLinear(BayesianModule):
    """
    Bayesian Linear layer, implements the linear layer proposed in Weight Uncertainty in Neural Networks (Bayes by Backprop paper).

    It is designed to be interoperable with the torch nn.Module API, and can be chained in nn.Sequential models together with layers from other libraries.

    parameters:
        in_features: int -> incoming features for the layer
        out_features: int -> output features for the layer
        bias: bool -> whether the bias will exist (True) or be set to zero (False)
        prior_sigma_1: float -> sigma of the first component of the scale mixture prior
        prior_sigma_2: float -> sigma of the second component of the scale mixture prior
        prior_pi: float -> pi on the scaled mixture prior
        posterior_mu_init: float -> init mean for the variational posterior mu of the weights
        posterior_rho_init: float -> init mean for the variational posterior rho of the weights
        freeze: bool -> whether the model starts with frozen (deterministic) weights or not
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_sigma_1=0.1,
                 prior_sigma_2=0.4,
                 prior_pi=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-7.0,
                 freeze=False,
                 prior_dist=None):
        super().__init__()

        # our main parameters
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.freeze = freeze

        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        # parameters for the scale mixture prior
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # Variational weight parameters and sample
        self.weight_mu = nn.Parameter(
            torch.Tensor(out_features, in_features).normal_(posterior_mu_init, 0.1))
        self.weight_rho = nn.Parameter(
            torch.Tensor(out_features, in_features).normal_(posterior_rho_init, 0.1))
        self.weight_sampler = TrainableRandomDistribution(self.weight_mu, self.weight_rho)

        # Variational bias parameters and sample
        self.bias_mu = nn.Parameter(
            torch.Tensor(out_features).normal_(posterior_mu_init, 0.1))
        self.bias_rho = nn.Parameter(
            torch.Tensor(out_features).normal_(posterior_rho_init, 0.1))
        self.bias_sampler = TrainableRandomDistribution(self.bias_mu, self.bias_rho)

        # Priors (as in the BBP paper)
        self.weight_prior_dist = PriorWeightDistribution(
            self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist=self.prior_dist)
        self.bias_prior_dist = PriorWeightDistribution(
            self.prior_pi, self.prior_sigma_1, self.prior_sigma_2, dist=self.prior_dist)

        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        # Sample the weights and forward with them
        # if the layer is frozen, use the deterministic forward pass
        if self.freeze:
            return self.forward_frozen(x)

        w = self.weight_sampler.sample()

        if self.bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)
        else:
            b = torch.zeros((self.out_features))
            b_log_posterior = 0
            b_log_prior = 0

        # Get the complexity cost
        self.log_variational_posterior = self.weight_sampler.log_posterior() + b_log_posterior
        self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior

        return F.linear(x, w, b)

    def forward_frozen(self, x):
        """
        Computes the feedforward operation with the expected value for weights and biases
        """
        if self.bias:
            return F.linear(x, self.weight_mu, self.bias_mu)
        else:
            return F.linear(x, self.weight_mu, torch.zeros(self.out_features))
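# A minimal usage sketch for the layer above, assuming BayesianLinear is importable
# from this package (the import path below is an assumption). It also shows one way
# the per-layer complexity cost (log q(w) - log p(w)) could be accumulated into an
# ELBO-style loss by hand, without any of the library's helper functions.
import torch
import torch.nn as nn
from blitz.modules import BayesianLinear  # assumed import path

model = nn.Sequential(BayesianLinear(10, 32), nn.ReLU(), BayesianLinear(32, 1))
x, y = torch.randn(16, 10), torch.randn(16, 1)

pred = model(x)                       # each call samples fresh weights and biases
fit_loss = nn.functional.mse_loss(pred, y)

# sum the complexity cost over the Bayesian layers touched by this forward pass
kl = sum(m.log_variational_posterior - m.log_prior
         for m in model.modules() if isinstance(m, BayesianLinear))

loss = fit_loss + kl / x.size(0)      # simple ELBO-style objective
loss.backward()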