def test_scale_mixture_any_prior(self):
        """ScaleMixturePrior accepts an arbitrary torch distribution as its prior.

        Samples a 10x10 weight from a GaussianVariational posterior, scores it
        under a ScaleMixturePrior wrapping StudentT(df=1, loc=1), and checks
        that the log-prior is not NaN and (for a finite log-posterior) is not
        positive.
        """
        mu = torch.Tensor(10, 10).uniform_(-1, 1)
        rho = torch.Tensor(10, 10).uniform_(-1, 1)

        dist = GaussianVariational(mu, rho)
        s1 = dist.sample()

        log_posterior = dist.log_posterior()

        prior_dist = ScaleMixturePrior(dist=torch.distributions.studentT.StudentT(1, 1))
        log_prior = prior_dist.log_prior(s1)

        # The original `log_prior == log_prior` idiom only fails for NaN;
        # state that intent explicitly.
        self.assertFalse(torch.isnan(log_prior).any())
        # For a finite log_posterior this is equivalent to log_prior <= 0.
        self.assertTrue(log_posterior <= log_posterior - log_prior)
# Example #2
class BayesianLSTM(BayesianModule):
    """
    Bayesian LSTM layer, implementing Bayes by Backprop as proposed in
    "Weight Uncertainty in Neural Networks" (Blundell et al., 2015).

    Each weight tensor (input-to-hidden, hidden-to-hidden and bias) is a
    GaussianVariational distribution. On every non-frozen forward pass a fresh
    sample is drawn, and the summed log variational posterior and log prior of
    that sample are stored in ``self.log_variational_posterior`` and
    ``self.log_prior``.

    It follows the torch nn.Module API, so it can be chained in nn.Sequential
    models with layers from other libraries.

    The four gates are stacked along the last weight axis in the order
    input, forget, cell candidate, output (each ``out_features`` wide).

    parameters:
        in_features: int -> incoming features for the layer
        out_features: int -> hidden/output features for the layer
        bias: bool -> whether the bias will be sampled (True) or set to zero (False)
        prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
        prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
        prior_pi: float -> pi on the scaled mixture prior
        freeze: bool -> whether the model will start with frozen (deterministic) weights, or not
        prior_dist: optional distribution passed to ScaleMixturePrior as
            ``dist``; when None (the default) ``dist`` is not passed at all.
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_sigma_1=1,
                 prior_sigma_2=0.002,
                 prior_pi=0.5,
                 freeze=False,
                 prior_dist=None):

        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias
        self.freeze = freeze

        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # Variational parameters for weight_ih: (in_features, 4 * out_features)
        self.weight_ih_mu = nn.Parameter(
            torch.Tensor(in_features, out_features * 4).uniform_(-0.2, 0.2))
        self.weight_ih_rho = nn.Parameter(
            torch.Tensor(in_features, out_features * 4).uniform_(-5, -4))
        self.weight_ih_sampler = GaussianVariational(self.weight_ih_mu,
                                                     self.weight_ih_rho)
        self.weight_ih = None

        # Variational parameters for weight_hh: (out_features, 4 * out_features)
        self.weight_hh_mu = nn.Parameter(
            torch.Tensor(out_features, out_features * 4).uniform_(-0.2, 0.2))
        self.weight_hh_rho = nn.Parameter(
            torch.Tensor(out_features, out_features * 4).uniform_(-5, -4))
        self.weight_hh_sampler = GaussianVariational(self.weight_hh_mu,
                                                     self.weight_hh_rho)
        self.weight_hh = None

        # Variational parameters for the bias: (4 * out_features,)
        self.bias_mu = nn.Parameter(
            torch.Tensor(out_features * 4).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(
            torch.Tensor(out_features * 4).uniform_(-5, -4))
        self.bias_sampler = GaussianVariational(self.bias_mu, self.bias_rho)
        self.bias = None

        # Prior distributions. `dist` is only forwarded when the caller supplied
        # one, so default construction is ScaleMixturePrior(pi, sigma_1, sigma_2).
        prior_kwargs = {} if prior_dist is None else {"dist": prior_dist}
        self.weight_ih_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                      self.prior_sigma_1,
                                                      self.prior_sigma_2,
                                                      **prior_kwargs)
        self.weight_hh_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                      self.prior_sigma_1,
                                                      self.prior_sigma_2,
                                                      **prior_kwargs)
        self.bias_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                 self.prior_sigma_1,
                                                 self.prior_sigma_2,
                                                 **prior_kwargs)

        self.log_prior = 0
        self.log_variational_posterior = 0

    def _zero_bias(self):
        # Zero bias on the same device/dtype as the weights; a bare
        # torch.zeros(...) would always be allocated on the CPU.
        return torch.zeros(self.out_features * 4,
                           device=self.weight_ih_mu.device,
                           dtype=self.weight_ih_mu.dtype)

    def sample_weights(self):
        """Draw fresh weights and record their log posterior / log prior.

        Sets self.weight_ih, self.weight_hh and self.bias to the new sample.
        With use_bias=False the bias is a zero tensor and contributes 0 to
        both log terms.
        """
        self.weight_ih = self.weight_ih_sampler.sample()
        self.weight_hh = self.weight_hh_sampler.sample()

        if self.use_bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)
        else:
            b = self._zero_bias()
            b_log_posterior = 0
            b_log_prior = 0

        self.bias = b

        # Gather the variational posterior and prior log-likelihoods.
        self.log_variational_posterior = self.weight_hh_sampler.log_posterior(
        ) + b_log_posterior + self.weight_ih_sampler.log_posterior()

        self.log_prior = self.weight_ih_prior_dist.log_prior(
            self.weight_ih
        ) + b_log_prior + self.weight_hh_prior_dist.log_prior(self.weight_hh)

    def get_frozen_weights(self):
        """Use the posterior means as deterministic weights (no sampling)."""
        self.weight_ih = self.weight_ih_mu
        self.weight_hh = self.weight_hh_mu
        if self.use_bias:
            self.bias = self.bias_mu
        else:
            self.bias = self._zero_bias()

    def forward(self, x, hidden_states=None):
        """Run the LSTM over a whole sequence.

        x is expected as (batch, sequence, feature), per the unpacking below.
        hidden_states, when given, is an (h, c) pair; otherwise both start at
        zero. Returns (hidden_seq, (h_t, c_t)), with hidden_seq shaped
        (batch, sequence, out_features).
        """
        bs, seq_sz, _ = x.size()
        HS = self.out_features

        # No initial state: start from zeros with one row per batch element.
        # This is the shape h_t/c_t take after the first step anyway, and it
        # matches BayesianGRU.
        if hidden_states is None:
            h_t = torch.zeros(bs, HS).to(x.device)
            c_t = torch.zeros(bs, HS).to(x.device)
        else:
            h_t, c_t = hidden_states

        if self.freeze:
            self.get_frozen_weights()
        else:
            self.sample_weights()

        hidden_seq = []
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # Batch all four gate pre-activations into one matrix multiplication.
            gates = x_t @ self.weight_ih + h_t @ self.weight_hh + self.bias

            i_t, f_t, g_t, o_t = (
                torch.sigmoid(gates[:, :HS]),  # input
                torch.sigmoid(gates[:, HS:HS * 2]),  # forget
                torch.tanh(gates[:, HS * 2:HS * 3]),  # cell candidate
                torch.sigmoid(gates[:, HS * 3:]),  # output
            )
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t.unsqueeze(0))

        hidden_seq = torch.cat(hidden_seq, dim=0)
        # reshape from (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()

        return hidden_seq, (h_t, c_t)
class BayesianLinear(BayesianModule):
    """
    Bayesian Linear layer, implementing Bayes by Backprop as proposed in
    "Weight Uncertainty in Neural Networks" (Blundell et al., 2015).

    The weight and bias are GaussianVariational distributions. On every
    non-frozen forward pass a fresh sample is drawn, and the summed log
    variational posterior and log prior of that sample are stored in
    ``self.log_variational_posterior`` and ``self.log_prior``.

    It follows the torch nn.Module API, so it can be chained in nn.Sequential
    models with layers from other libraries.

    parameters:
        in_features: int -> incoming features for the layer
        out_features: int -> output features for the layer
        bias: bool -> whether the bias will exist (True) or not (False)
        prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
        prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
        prior_pi: float -> pi on the scaled mixture prior
        freeze: bool -> whether the model will start with frozen (deterministic) weights, or not
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_sigma_1=1,
                 prior_sigma_2=0.002,
                 prior_pi=0.5,
                 freeze=False):
        super().__init__()

        # our main parameters
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.freeze = freeze

        # parameters for the scale mixture prior
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi

        # Variational weight parameters and sampler: (out_features, in_features),
        # the layout F.linear expects.
        self.weight_mu = nn.Parameter(
            torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(
            torch.Tensor(out_features, in_features).uniform_(-5, -4))
        self.weight_sampler = GaussianVariational(self.weight_mu,
                                                  self.weight_rho)

        # Variational bias parameters and sampler
        self.bias_mu = nn.Parameter(
            torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(
            torch.Tensor(out_features).uniform_(-5, -4))
        self.bias_sampler = GaussianVariational(self.bias_mu, self.bias_rho)

        # Priors (as in the BBP paper)
        self.weight_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                   self.prior_sigma_1,
                                                   self.prior_sigma_2)
        self.bias_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                 self.prior_sigma_1,
                                                 self.prior_sigma_2)
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        """Apply the layer with freshly sampled weights (or means when frozen).

        Also records the complexity-cost terms for this sample.
        """
        if self.freeze:
            return self.forward_frozen(x)

        w = self.weight_sampler.sample()

        if self.bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)
        else:
            # F.linear treats bias=None as "no bias". This avoids allocating a
            # CPU zeros tensor that would clash with inputs on another device.
            b = None
            b_log_posterior = 0
            b_log_prior = 0

        # Get the complexity cost
        self.log_variational_posterior = self.weight_sampler.log_posterior(
        ) + b_log_posterior
        self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior

        return F.linear(x, w, b)

    def forward_frozen(self, x):
        """
        Computes the feedforward operation with the expected value for weight and biases
        """
        bias = self.bias_mu if self.bias else None
        return F.linear(x, self.weight_mu, bias)
# Example #4
class BayesianGRU(BayesianModule):
    """
    Bayesian GRU layer, implementing Bayes by Backprop as proposed in
    "Weight Uncertainty in Neural Networks" (Blundell et al., 2015).

    Each weight tensor (input-to-hidden, hidden-to-hidden and bias) is a
    GaussianVariational distribution. On every non-frozen forward pass a fresh
    sample is drawn, and the summed log variational posterior and log prior of
    that sample are stored in ``self.log_variational_posterior`` and
    ``self.log_prior``.

    It follows the torch nn.Module API, so it can be chained in nn.Sequential
    models with layers from other libraries.

    parameters:
        in_features: int -> incoming features for the layer
        out_features: int -> hidden/output features for the layer
        bias: bool -> whether the bias will be sampled (True) or set to zero (False)
        prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
        prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
        prior_pi: float -> pi on the scaled mixture prior
        posterior_mu_init: float -> mean of the normal (std 0.1) used to initialise every *_mu parameter
        posterior_rho_init: float -> mean of the normal (std 0.1) used to initialise every *_rho parameter
        freeze: bool -> whether the model will start with frozen (deterministic) weights, or not
        prior_dist: optional distribution forwarded to ScaleMixturePrior as ``dist``
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_sigma_1=0.1,
                 prior_sigma_2=0.002,
                 prior_pi=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-6.0,
                 freeze=False,
                 prior_dist=None):

        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias
        self.freeze = freeze

        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # Variational parameters for weight_ih: (in_features, 4 * out_features)
        self.weight_ih_mu = nn.Parameter(
            torch.Tensor(in_features,
                         out_features * 4).normal_(posterior_mu_init, 0.1))
        self.weight_ih_rho = nn.Parameter(
            torch.Tensor(in_features,
                         out_features * 4).normal_(posterior_rho_init, 0.1))
        self.weight_ih_sampler = GaussianVariational(self.weight_ih_mu,
                                                     self.weight_ih_rho)
        self.weight_ih = None

        # Variational parameters for weight_hh: (out_features, 4 * out_features)
        self.weight_hh_mu = nn.Parameter(
            torch.Tensor(out_features,
                         out_features * 4).normal_(posterior_mu_init, 0.1))
        self.weight_hh_rho = nn.Parameter(
            torch.Tensor(out_features,
                         out_features * 4).normal_(posterior_rho_init, 0.1))
        self.weight_hh_sampler = GaussianVariational(self.weight_hh_mu,
                                                     self.weight_hh_rho)
        self.weight_hh = None

        # Variational parameters for the bias: (4 * out_features,)
        self.bias_mu = nn.Parameter(
            torch.Tensor(out_features * 4).normal_(posterior_mu_init, 0.1))
        self.bias_rho = nn.Parameter(
            torch.Tensor(out_features * 4).normal_(posterior_rho_init, 0.1))
        self.bias_sampler = GaussianVariational(self.bias_mu, self.bias_rho)
        self.bias = None

        # our prior distributions
        self.weight_ih_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                      self.prior_sigma_1,
                                                      self.prior_sigma_2,
                                                      dist=self.prior_dist)
        self.weight_hh_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                      self.prior_sigma_1,
                                                      self.prior_sigma_2,
                                                      dist=self.prior_dist)
        self.bias_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                 self.prior_sigma_1,
                                                 self.prior_sigma_2,
                                                 dist=self.prior_dist)

        self.log_prior = 0
        self.log_variational_posterior = 0

    def _zero_bias(self):
        # A zero *tensor* (not the int 0) so the gate code can slice it. It is
        # placed on the weights' device/dtype so it works on GPU as well.
        return torch.zeros(self.out_features * 4,
                           device=self.weight_ih_mu.device,
                           dtype=self.weight_ih_mu.dtype)

    def sample_weights(self):
        """Draw fresh weights and record their log posterior / log prior.

        Returns (weight_ih, weight_hh, bias). With use_bias=False the bias is
        a zero tensor and contributes 0 to both log terms.
        """
        weight_ih = self.weight_ih_sampler.sample()
        weight_hh = self.weight_hh_sampler.sample()

        if self.use_bias:
            bias = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(bias)
        else:
            bias = self._zero_bias()
            b_log_posterior = 0
            b_log_prior = 0

        # Gather the variational posterior and prior log-likelihoods.
        self.log_variational_posterior = self.weight_hh_sampler.log_posterior(
        ) + b_log_posterior + self.weight_ih_sampler.log_posterior()

        self.log_prior = self.weight_ih_prior_dist.log_prior(
            weight_ih) + b_log_prior + self.weight_hh_prior_dist.log_prior(
                weight_hh)
        return weight_ih, weight_hh, bias

    def get_frozen_weights(self):
        """Return the posterior means (weight_ih, weight_hh, bias); no sampling."""
        weight_ih = self.weight_ih_mu
        weight_hh = self.weight_hh_mu
        bias = self.bias_mu if self.use_bias else self._zero_bias()
        return weight_ih, weight_hh, bias

    def _recurrence(self, x, hidden_states, weight_ih, weight_hh, bias):
        """Run the GRU recurrence over x with the given weights.

        x is expected as (batch, sequence, feature), per the unpacking below.
        hidden_states, when given, is the initial hidden state; otherwise it
        starts at zeros of shape (batch, out_features).

        Column layout of the 4 * out_features axis (HS = out_features):
            [0, 2HS)    reset and update gates: input + hidden projections, shared bias
            [2HS, 3HS)  candidate input projection (weight_ih) and its bias
            [3HS, 4HS)  candidate hidden projection (weight_hh) and its bias
        weight_ih[:, 3HS:4HS] and weight_hh[:, 2HS:3HS] are not read.

        Returns (hidden_seq, h_t), with hidden_seq shaped
        (batch, sequence, out_features).
        """
        bs, seq_sz, _ = x.size()
        HS = self.out_features

        if hidden_states is None:
            h_t = torch.zeros(bs, HS).to(x.device)
        else:
            h_t = hidden_states

        # These slices do not change across time steps; take them once.
        w_ih_rz, w_hh_rz, b_rz = (weight_ih[:, :HS * 2],
                                  weight_hh[:, :HS * 2],
                                  bias[:HS * 2])
        w_ih_n, b_ih_n = weight_ih[:, HS * 2:HS * 3], bias[HS * 2:HS * 3]
        w_hh_n, b_hh_n = weight_hh[:, HS * 3:HS * 4], bias[HS * 3:HS * 4]

        hidden_seq = []
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # Reset and update gates in a single matrix multiplication.
            A_t = x_t @ w_ih_rz + h_t @ w_hh_rz + b_rz
            r_t = torch.sigmoid(A_t[:, :HS])
            z_t = torch.sigmoid(A_t[:, HS:HS * 2])

            n_t = torch.tanh(x_t @ w_ih_n + b_ih_n + r_t *
                             (h_t @ w_hh_n + b_hh_n))
            h_t = (1 - z_t) * n_t + z_t * h_t

            hidden_seq.append(h_t.unsqueeze(0))

        hidden_seq = torch.cat(hidden_seq, dim=0)
        # reshape from (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()

        return hidden_seq, h_t

    def forward_(self, x, hidden_states):
        """Forward pass with freshly sampled weights (updates the log terms)."""
        weight_ih, weight_hh, bias = self.sample_weights()
        return self._recurrence(x, hidden_states, weight_ih, weight_hh, bias)

    def forward_frozen(self, x, hidden_states):
        """Forward pass with the posterior means (deterministic)."""
        weight_ih, weight_hh, bias = self.get_frozen_weights()
        return self._recurrence(x, hidden_states, weight_ih, weight_hh, bias)

    def forward(self, x, hidden_states=None):
        if self.freeze:
            return self.forward_frozen(x, hidden_states)

        return self.forward_(x, hidden_states)
class BayesianEmbedding(BayesianModule):
    """
    Bayesian Embedding layer, implementing Bayes by Backprop as proposed in
    "Weight Uncertainty in Neural Networks" (Blundell et al., 2015).

    The embedding matrix is a GaussianVariational distribution. On every
    non-frozen forward pass a fresh sample is drawn, and its log variational
    posterior and log prior are stored in ``self.log_variational_posterior``
    and ``self.log_prior``.

    It follows the torch nn.Module API, so it can be chained in nn.Sequential
    models with layers from other libraries.

    parameters:
        num_embeddings int -> Size of the vocabulary
        embedding_dim int -> Dimension of the embedding
        padding_idx int -> If given, forwarded to F.embedding: entries at padding_idx do not
            contribute to the gradient. Unlike nn.Embedding, this layer does NOT zero the
            padding row of weight_mu at initialisation, and non-frozen passes use a freshly
            sampled row.
        max_norm float -> If given, each looked-up embedding vector with norm larger than
            max_norm is renormalized to have norm max_norm.
        norm_type float -> The p of the p-norm to compute for the max_norm option. Default 2.
        scale_grad_by_freq bool -> If True, scale gradients by the inverse of frequency of the
            words in the mini-batch. Default False.
        sparse bool -> Forwarded to F.embedding; if True, the gradient w.r.t. the weight
            argument is a sparse tensor.
        prior_sigma_1 float -> sigma of one of the prior w distributions to mixture
        prior_sigma_2 float -> sigma of one of the prior w distributions to mixture
        prior_pi float -> factor to scale the gaussian mixture of the model prior distribution
        posterior_mu_init float -> mean of the normal (std 0.1) used to initialise weight_mu
        posterior_rho_init float -> mean of the normal (std 0.1) used to initialise weight_rho
        freeze bool -> whether the model is instantiated as frozen (deterministic weights on
            the feedforward op)
        prior_dist -> optional distribution forwarded to ScaleMixturePrior as ``dist``
    """
    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 max_norm=None,
                 norm_type=2.,
                 scale_grad_by_freq=False,
                 sparse=False,
                 prior_sigma_1=0.1,
                 prior_sigma_2=0.002,
                 prior_pi=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-6.0,
                 freeze=False,
                 prior_dist=None):
        super().__init__()

        self.freeze = freeze

        # parameters for the scale mixture prior and posterior initialisation
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # options forwarded verbatim to F.embedding
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        # Variational weight parameters and sampler: (num_embeddings, embedding_dim)
        self.weight_mu = nn.Parameter(
            torch.Tensor(num_embeddings,
                         embedding_dim).normal_(posterior_mu_init, 0.1))
        self.weight_rho = nn.Parameter(
            torch.Tensor(num_embeddings,
                         embedding_dim).normal_(posterior_rho_init, 0.1))
        self.weight_sampler = GaussianVariational(self.weight_mu,
                                                  self.weight_rho)

        # Priors (as in the BBP paper)
        self.weight_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                   self.prior_sigma_1,
                                                   self.prior_sigma_2,
                                                   dist=self.prior_dist)
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        """Look up x in a freshly sampled embedding matrix (means when frozen).

        Also records the complexity-cost terms for this sample.
        """
        if self.freeze:
            return self.forward_frozen(x)

        w = self.weight_sampler.sample()

        # Get the complexity cost
        self.log_variational_posterior = self.weight_sampler.log_posterior()
        self.log_prior = self.weight_prior_dist.log_prior(w)

        # With max_norm set, F.embedding renormalises its weight argument in
        # place. w was already consumed by the differentiable log_prior (and
        # possibly the sampler's log_posterior) above, so the PyTorch docs require
        # a clone; otherwise backward can fail on a modified saved tensor.
        weight = w.clone() if self.max_norm is not None else w

        return F.embedding(x, weight, self.padding_idx, self.max_norm,
                           self.norm_type, self.scale_grad_by_freq,
                           self.sparse)

    def forward_frozen(self, x):
        # Deterministic lookup in the posterior means. Note that with max_norm
        # set, F.embedding renormalises weight_mu in place, as nn.Embedding does.
        return F.embedding(x, self.weight_mu, self.padding_idx, self.max_norm,
                           self.norm_type, self.scale_grad_by_freq,
                           self.sparse)
# Example #6
class BayesianConv1d(BayesianModule):
    """
    Bayesian Conv1d layer, implementing Bayes by Backprop as proposed in
    "Weight Uncertainty in Neural Networks" (Blundell et al., 2015).

    The convolution kernel and bias are GaussianVariational distributions. On
    every non-frozen forward pass a fresh sample is drawn, and the summed log
    variational posterior and log prior of that sample are stored in
    ``self.log_variational_posterior`` and ``self.log_prior``.

    It follows the torch nn.Module API, so it can be chained in nn.Sequential
    models with layers from other libraries.

    parameters:
        in_channels: int -> incoming channels for the layer
        out_channels: int -> output channels for the layer
        kernel_size: int -> length of the 1D kernel; the weight has shape
            (out_channels, in_channels // groups, kernel_size)
        groups: int -> number of groups into which the channels are split for the convolution
        stride: int -> stride of the convolution
        padding: int -> size of padding (0 if no padding)
        dilation: int -> dilation of the weights applied on the input tensor
        bias: bool -> whether the bias will exist (True) or not (False)
        prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
        prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
        prior_pi: float -> pi on the scaled mixture prior
        posterior_mu_init: float -> mean of the normal (std 0.1) used to initialise every *_mu parameter
        posterior_rho_init: float -> mean of the normal (std 0.1) used to initialise every *_rho parameter
        freeze: bool -> whether the model will start with frozen (deterministic) weights, or not
        prior_dist: optional distribution forwarded to ScaleMixturePrior as ``dist``
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 groups=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 prior_sigma_1=0.1,
                 prior_sigma_2=0.002,
                 prior_pi=1,
                 posterior_mu_init=0,
                 posterior_rho_init=-6.0,
                 freeze=False,
                 prior_dist=None):
        super().__init__()

        # our main parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.freeze = freeze
        self.kernel_size = kernel_size
        self.groups = groups
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.bias = bias

        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        # parameters for the scale mixture prior
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi
        self.prior_dist = prior_dist

        # our weights: (out_channels, in_channels // groups, kernel_size), as F.conv1d expects
        self.weight_mu = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         kernel_size).normal_(posterior_mu_init, 0.1))
        self.weight_rho = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         kernel_size).normal_(posterior_rho_init, 0.1))
        self.weight_sampler = GaussianVariational(self.weight_mu,
                                                  self.weight_rho)

        # our biases
        self.bias_mu = nn.Parameter(
            torch.Tensor(out_channels).normal_(posterior_mu_init, 0.1))
        self.bias_rho = nn.Parameter(
            torch.Tensor(out_channels).normal_(posterior_rho_init, 0.1))
        self.bias_sampler = GaussianVariational(self.bias_mu, self.bias_rho)

        # Priors (as in the BBP paper)
        self.weight_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                   self.prior_sigma_1,
                                                   self.prior_sigma_2,
                                                   dist=self.prior_dist)
        self.bias_prior_dist = ScaleMixturePrior(self.prior_pi,
                                                 self.prior_sigma_1,
                                                 self.prior_sigma_2,
                                                 dist=self.prior_dist)
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        """Convolve x with freshly sampled weights (means when frozen).

        Also records the complexity-cost terms for this sample.
        """
        if self.freeze:
            return self.forward_frozen(x)

        w = self.weight_sampler.sample()

        if self.bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)
        else:
            # F.conv1d treats bias=None as "no bias". This avoids allocating a
            # CPU zeros tensor that would clash with inputs on another device.
            b = None
            b_log_posterior = 0
            b_log_prior = 0

        self.log_variational_posterior = self.weight_sampler.log_posterior(
        ) + b_log_posterior
        self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior

        return F.conv1d(input=x,
                        weight=w,
                        bias=b,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)

    def forward_frozen(self, x):
        """Convolve with the posterior means (deterministic, no sampling)."""
        bias = self.bias_mu if self.bias else None

        return F.conv1d(input=x,
                        weight=self.weight_mu,
                        bias=bias,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)