Example #1
def show_per_pixel_uncertainty_kl(model: V3AE,
                                  z_star: Union[None, torch.Tensor],
                                  extent: Union[int, float]):
    z_star = torch.Tensor([[-3.5, -3.5]]) if z_star is None else z_star
    z_star = z_star[None, :]
    _, _, α_z, β_z = model.parametrise_z(z_star)
    _, q_λ_z, p_λ = model.sample_precision(α_z, β_z)
    q_α = q_λ_z.base_dist.concentration.flatten()
    q_β = q_λ_z.base_dist.rate.flatten()
    p_α = p_λ.base_dist.concentration.flatten()
    p_β = p_λ.base_dist.rate.flatten()
    q_λ_z = D.Gamma(q_α, q_β)
    p_λ = D.Gamma(p_α, p_β)
    kl = model.kl(q_λ_z, p_λ)
    var = β_z / (α_z - 1)
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(
        var.reshape((28, 28)),
        extent=(-extent, extent, -extent, extent),
        vmax=np.percentile(var.flatten(), 75),
        # vmin=np.percentile(var.flatten(), 2),
        # cmap=cmap,
    )
    cax2 = ax2.imshow(
        kl.reshape((28, 28)),
        extent=(-extent, extent, -extent, extent),
        # vmax=np.percentile(kl.flatten(), 75),
        # vmin=np.percentile(var.flatten(), 2),
        # cmap=cmap,
    )
    ax2.title.set_text(f"Avg kl: {kl.mean():.2f}\nSum kl: {kl.sum():.2f}")
    fig.colorbar(cax2)
    plt.show()
Example #2
def coupled_gibbs_kernel(lamb1, lamb2, beta1, beta2, s, t, alpha, gamma,
                         delta):
    lamb1, lamb2 = maximal_coupling(dist.Gamma(alpha + s, beta1 + t),
                                    dist.Gamma(alpha + s, beta2 + t)).sample()
    beta1, beta2 = maximal_coupling(
        dist.Gamma(gamma + 10 * alpha, delta + lamb1.sum()),
        dist.Gamma(gamma + 10 * alpha, delta + lamb2.sum())).sample()
    return lamb1, lamb2, beta1, beta2
Example #3
 def kl(
     α: torch.Tensor,
     β: torch.Tensor,
     a: torch.Tensor,
     b: torch.Tensor,
     ε: float = 1e-10,
 ) -> torch.Tensor:
     β = β + ε
     qp = D.Gamma(α, β)
     pp = D.Gamma(a, b)
     return D.kl_divergence(qp, pp)
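A quick way to sanity-check what D.kl_divergence returns for two Gamma distributions is to compare the closed-form value with a Monte Carlo estimate. The snippet below is my own minimal check with arbitrary parameter values, not part of the original example:

import torch
import torch.distributions as D

q = D.Gamma(torch.tensor(2.0), torch.tensor(1.5))  # arbitrary variational Gamma(α, β)
p = D.Gamma(torch.tensor(3.0), torch.tensor(1.0))  # arbitrary prior Gamma(a, b)

kl_closed = D.kl_divergence(q, p)  # closed-form KL(q || p)
samples = q.sample((100_000,))
kl_mc = (q.log_prob(samples) - p.log_prob(samples)).mean()
print(kl_closed.item(), kl_mc.item())  # the two values should agree closely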
Example #4
def multiplicative_gamma(a1, a2, h):
    """
    Samples as in Eq (2) from [1]
    h is the number of layers
    """
    g1 = dist.Gamma(torch.tensor([a1]),
                    torch.tensor([1.])).sample([1]).squeeze(0)
    g2 = dist.Gamma(torch.tensor([a2]),
                    torch.tensor([1.])).sample([h - 1]).squeeze()
    taus = torch.cat((g1, g1 * torch.cumprod(g2, dim=0)), 0)
    ##
    return taus
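A hypothetical usage sketch (the argument values and the `dist` alias are my own assumptions, not from the snippet): with a2 > 1, each additional layer multiplies by a Gamma variate whose mean exceeds one, so the returned precisions tend to increase with depth, shrinking deeper layers more strongly.

import torch
import torch.distributions as dist  # alias assumed by the snippet above

torch.manual_seed(0)
taus = multiplicative_gamma(a1=2.0, a2=3.0, h=5)  # illustrative arguments
print(taus.shape)  # torch.Size([5]), one precision per layer
print(taus)        # later entries tend to be larger than earlier ones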
Example #5
def additive_gamma(a1, a2, h):
    """
    Samples from additive gamma process
    h is the number of layers
    """
    g1 = dist.Gamma(torch.tensor([a1]),
                    torch.tensor([1.])).sample([1]).squeeze(0)
    g2 = dist.Gamma(torch.tensor([a2]),
                    torch.tensor([1.])).sample([h - 1]).squeeze()
    taus = torch.cat((g1, g1 + torch.cumsum(g2, dim=0)), 0)
    ##
    return taus
Example #6
    def compute_batch_loss(self, x, y, shape, scale):
        # compute the negative log-likelihood (cross-entropy)
        likelihood = -torch.sum(torch.log_softmax(y, 1) * x, 1)

        # compute KL divergence
        prior_distribution = distributions.Gamma(self.shape_prior,
                                                 self.scale_prior)
        local_distribution = distributions.Gamma(shape, scale)
        kld = torch.sum(
            distributions.kl_divergence(local_distribution,
                                        prior_distribution), 1)

        return likelihood, kld
Example #7
    def _sample_std(self, shape, rate):
        with torch.no_grad():
            gamma_dist = dist.Gamma(shape, rate)
            inv_var = gamma_dist.rsample()
            std = 1. / (torch.sqrt(inv_var) + self.eps)

            return std
Example #8
    def generate(
        self,
        n_samples: int = 100,
        genes: Union[list, np.ndarray] = None,
        batch_size: int = 64,
        #batch_size: int = 128,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Create observation samples from the Posterior Predictive distribution

        :param n_samples: Number of required samples for each cell
        :param genes: Indices of genes of interest
        :param batch_size: Desired Batch size to generate data

        :return: Tuple (x_new, x_old)
            Where x_old has shape (n_cells, n_genes)
            Where x_new has shape (n_cells, n_genes, n_samples)
        """
        assert self.model.reconstruction_loss in ["zinb", "nb"]
        zero_inflated = self.model.reconstruction_loss == "zinb"
        x_old = []
        x_new = []
        for tensors in self.update({"batch_size": batch_size}):
            sample_batch, _, _, batch_index, labels = tensors
            outputs = self.model.inference(sample_batch,
                                           batch_index=batch_index,
                                           y=labels,
                                           n_samples=n_samples)
            px_r = outputs["px_r"]
            px_rate = outputs["px_rate"]
            px_dropout = outputs["px_dropout"]

            p = px_rate / (px_rate + px_r)
            r = px_r
            # Important remark: Gamma is parametrized by the rate = 1/scale!
            l_train = distributions.Gamma(concentration=r,
                                          rate=(1 - p) / p).sample()
            # Clamping as distributions objects can have buggy behaviors when
            # their parameters are too high
            l_train = torch.clamp(l_train, max=1e8)
            # Shape : (n_samples, n_cells_batch, n_genes)
            gene_expressions = distributions.Poisson(l_train).sample()
            if zero_inflated:
                p_zero = (1.0 + torch.exp(-px_dropout)).pow(-1)
                random_prob = torch.rand_like(p_zero)
                gene_expressions[random_prob <= p_zero] = 0

            gene_expressions = gene_expressions.permute(
                [1, 2, 0])  # Shape : (n_cells_batch, n_genes, n_samples)

            x_old.append(sample_batch.cpu())
            x_new.append(gene_expressions.cpu())

        x_old = torch.cat(x_old)  # Shape (n_cells, n_genes)
        x_new = torch.cat(x_new)  # Shape (n_cells, n_genes, n_samples)
        if genes is not None:
            gene_ids = self.gene_dataset.genes_to_index(genes)
            x_new = x_new[:, gene_ids, :]
            x_old = x_old[:, gene_ids]
        return x_new.numpy(), x_old.numpy()
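For reference, the Gamma-Poisson construction above is the standard way to simulate a negative binomial: drawing a rate from Gamma(r, (1 - p) / p) and then a count from a Poisson with that rate is marginally NB with r failures and success probability p, which is why the Gamma is given rate = (1 - p) / p. A minimal standalone sketch of my own (arbitrary values, scVI objects stripped out):

import torch
import torch.distributions as distributions

r = torch.tensor(5.0)  # arbitrary inverse dispersion, playing the role of px_r
p = torch.tensor(0.3)  # arbitrary value of px_rate / (px_rate + px_r)

lam = distributions.Gamma(concentration=r, rate=(1 - p) / p).sample((100_000,))
counts = distributions.Poisson(lam).sample()

print(counts.mean().item(), (r * p / (1 - p)).item())  # both ≈ the NB mean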
Example #9
    def forward(self, input_, hx, update_noise=True):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state,
                along with the gate values [a_i, b_i, a_f, b_f] and a
                trailing None placeholder.
        """

        h_0, c_0 = hx
        new_w_hh = self.weight_hh_wdrop \
            if self.weight_hh_wdrop is not None else self.weight_hh
        gates = F.linear(input_, self.weight_ih, self.bias) + F.linear(h_0, new_w_hh)

        alpha, o, g = torch.split(gates, [self.hidden_size * 4, self.hidden_size, self.hidden_size], dim=1)
        alpha = F.softplus(alpha) + self.eps
        if self.training:
            G = to_var(tdist.Gamma(alpha, 1.0).rsample()) + self.eps
            G0, G1, G2, G3 = G.chunk(4, 1)
        else:
            G0, G1, G2, G3 = alpha.chunk(4, 1)  # Expectation

        a_i, b_i = G0, G1
        a_f, b_f = G2, G3
        sigm_i = a_i * (1.0 / (a_i + b_i))
        sigm_f = a_f * (1.0 / (a_f + b_f))  # 1 - G2 * (1/(G2+G0))

        c_1 = sigm_f * c_0 + sigm_i * torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(c_1)
        gate_value = [a_i, b_i, a_f, b_f]
        return h_1, c_1, gate_value, None
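A side note on the gate construction (my own illustration, not from the source): for independent G0 ~ Gamma(a, 1) and G1 ~ Gamma(b, 1), the ratio G0 / (G0 + G1) is Beta(a, b)-distributed, so the training branch draws stochastic gates while the eval branch's a / (a + b) uses exactly their expectation.

import torch
import torch.distributions as tdist

a, b = torch.tensor(2.0), torch.tensor(5.0)  # arbitrary gate concentrations
g0 = tdist.Gamma(a, torch.tensor(1.0)).sample((100_000,))
g1 = tdist.Gamma(b, torch.tensor(1.0)).sample((100_000,))
gate = g0 / (g0 + g1)

print(gate.mean().item(), (a / (a + b)).item())  # both ≈ 2 / 7 ≈ 0.2857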
Example #10
    def getGPmodel(self):

        pyro.set_rng_seed(self.seed)
        pyro.clear_param_store()

        args = self.getInitXY()
        if self.dim1(args):
            X4Kernel = torch.tensor(self.getFlotList(args[0]))
            dim = 1
        else:
            X4Kernel = torch.tensor([self.getFlotList(args[0])])
            dim = X4Kernel.shape[1]
        y4Kernel = torch.tensor(self.getFlotList(args[1]))

        #X4Kernel = torch.tensor([[2.0, 40], [6.0, 40], [6.0, 80], [8.0, 90]])
        #y4Kernel = torch.tensor([1.0, 2.0, 3.0 ,4.0])
        #kernel = gp.kernels.RBF(input_dim=dim)

        kernel = gp.kernels.Matern52(input_dim=dim)
        kernel.set_prior("variance",
                         dist.Uniform(torch.tensor(0.5), torch.tensor(1.5)))
        kernel.set_prior("lengthscale",
                         dist.Gamma(torch.tensor(7.5), torch.tensor(1.)))

        gpmodel = gp.models.GPRegression(X4Kernel,
                                         y4Kernel,
                                         kernel,
                                         noise=torch.tensor(self.noise),
                                         jitter=1.0e-4)  #1.0e-4

        return gpmodel
Example #11
def get_robust_regression(device: torch.device) -> GetterReturnType:
    N = 10
    K = 10

    # X.shape: (N, K + 1), Y.shape: (N, 1)
    X = torch.rand(N, K + 1, device=device)
    Y = torch.rand(N, 1, device=device)

    # Predefined nu_alpha and nu_beta, nu_alpha.shape: (1, 1), nu_beta.shape: (1, 1)
    nu_alpha = torch.rand(1, 1, device=device)
    nu_beta = torch.rand(1, 1, device=device)
    nu = dist.Gamma(nu_alpha, nu_beta)

    # Predefined sigma_rate: sigma_rate.shape: (N, 1)
    sigma_rate = torch.rand(N, 1, device=device)
    sigma = dist.Exponential(sigma_rate)

    # Predefined beta_mean and beta_sigma: beta_mean.shape: (K + 1, 1), beta_sigma.shape: (K + 1, 1)
    beta_mean = torch.rand(K + 1, 1, device=device)
    beta_sigma = torch.rand(K + 1, 1, device=device)
    beta = dist.Normal(beta_mean, beta_sigma)

    nu_value = nu.sample()
    nu_value.requires_grad_(True)

    sigma_value = sigma.sample()
    sigma_unconstrained_value = sigma_value.log()
    sigma_unconstrained_value.requires_grad_(True)

    beta_value = beta.sample()
    beta_value.requires_grad_(True)

    def forward(nu_value: Tensor, sigma_unconstrained_value: Tensor, beta_value: Tensor) -> Tensor:
        sigma_constrained_value = sigma_unconstrained_value.exp()
        mu = X.mm(beta_value)

        # For this model, we need to compute the following three scores:
        # We need to compute the first and second gradient of this score with respect
        # to nu_value.
        nu_score = dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum() \
            + nu.log_prob(nu_value)



        # We need to compute the first and second gradient of this score with respect
        # to sigma_unconstrained_value.
        sigma_score = dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum() \
            + sigma.log_prob(sigma_constrained_value) \
            + sigma_unconstrained_value



        # We need to compute the first and second gradient of this score with respect
        # to beta_value.
        beta_score = dist.StudentT(nu_value, mu, sigma_constrained_value).log_prob(Y).sum() \
            + beta.log_prob(beta_value)

        return nu_score.sum() + sigma_score.sum() + beta_score.sum()

    return forward, (nu_value.to(device), sigma_unconstrained_value.to(device), beta_value.to(device))
Example #12
    def score_invscales_type(self, subid, invscales):
        """
        Compute the log-probability of each sub-stroke's scale parameter
        under the prior

        Parameters
        ----------
        subid : (nsub,) tensor
            sub-stroke ID sequence
        invscales : (nsub,) tensor
            scale values for each sub-stroke

        Returns
        -------
        ll : (nsub,) tensor
            vector of log-likelihood scores
        """
        if self.isunif:
            raise NotImplementedError
        # make sure these are vectors
        assert len(invscales.shape) == 1
        assert len(subid.shape) == 1
        assert len(invscales) == len(subid)
        # create gamma distribution
        gamma = dist.Gamma(self.scales_con[subid], self.scales_rate[subid])
        # score points using the gamma distribution
        ll = gamma.log_prob(invscales)

        return ll
Example #13
def learn_dist(model, dataloader, args, num_samples):
    assert args.use_weights
    if args.concentration > 0:
        prior_weight_dist = D.Gamma(
            torch.tensor([args.concentration / args.T]), torch.tensor([1.0]))
    params = []
    for j in range(num_samples):
        for i, data in enumerate(dataloader):
            x, y, weights = data
            y = y.to(args.device)
            weights = weights.squeeze().to(args.device)
            x = x.view(x.shape[0], -1).to(args.device)
            # do we need to sample from the prior?
            if args.concentration > 0:
                u_samples, x_samples = model.generate_prior_samples(args.T)
                u_weights = prior_weight_dist.sample(
                    torch.Size([args.T])).squeeze()
                Z = torch.sum(weights) + torch.sum(u_weights)
                weights = weights / Z
                u_weights = u_weights / Z
            else:
                weights = weights / torch.sum(weights)
                u_samples = None
                u_weights = None
                x_samples = None
            # get weighted mean, covariance
            mean, cov = model.weighted_ml_params(
                x, weights, y if args.cond_label_size else None, u_samples,
                u_weights)
            params.append([mean, cov])

    return params
Example #14
    def sample_gamma(self, shape, scale):
        augment = 10
        # get Gamma(shape + factor, 1)
        with torch.no_grad():
            sample = distributions.Gamma(shape + augment, 1).sample()
            eps = torch.sqrt(9. * (shape + augment) - 3.) * (
                (sample / (shape + augment - 1. / 3.)) ** (1. / 3.) - 1.)

        z = (shape + augment - 1. / 3.) * (
            1. + eps / torch.sqrt(9. * (shape + augment) - 3.)) ** 3.

        # reduce factor
        with torch.no_grad():
            expand_shape = shape.unsqueeze(-1).repeat(1, 1, augment)
            factor_range = torch.arange(
                1, augment + 1, dtype=torch.float,
                device=self.device).expand_as(expand_shape)
            u = distributions.Uniform(
                torch.zeros(factor_range.size(), device=self.device),
                torch.ones(factor_range.size(), device=self.device)).sample()

        u_prod = torch.prod(
            u ** (1. / (expand_shape + factor_range - 1. + 1e-12)), -1)
        z = z * u_prod * scale

        return z
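The augmentation step above relies on the identity that a Gamma(α, 1) variate can be written as Gamma(α + B, 1) · ∏_{i=1..B} U_i^{1/(α + i − 1)} with U_i ~ Uniform(0, 1); boosting the shape by B keeps the reparameterised part (which the example routes through what appears to be the Marsaglia-Tsang cube transform) well behaved even for small α. Below is a standalone sketch of just that identity, with arbitrary values of my own and no class context:

import torch
import torch.distributions as distributions

alpha = torch.tensor(0.3)  # arbitrary small shape
B = 10                     # augmentation factor, as `augment` above

g = distributions.Gamma(alpha + B, torch.tensor(1.0)).sample((100_000,))
i = torch.arange(1, B + 1, dtype=torch.float)
u = torch.rand(100_000, B)
z = g * torch.prod(u ** (1.0 / (alpha + i - 1.0)), dim=-1)

print(z.mean().item(), alpha.item())  # Gamma(α, 1) has mean α, so both ≈ 0.3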
Example #15
 def __init__(self, dim, sigma=None, lambd=None, k=1, j=0, device='cpu'):
     self.k = k
     self.j = j
     super().__init__(dim, sigma, lambd, device)
     self.gamma_dist = D.Gamma(concentration=torch.tensor((dim - j) / k,
                                                          device=device),
                               rate=1)
Example #16
 def __init__(self, dim, sigma=None, lambd=None, k=1, device='cpu'):
     self.k = k
     if k <= 0:
         raise ValueError(f'k must be positive: {k} received')
     super().__init__(dim, sigma, lambd, device)
     self.l1_radii = self.l1_rho = self._l1_table_info = None
     self.gamma_dist = D.Gamma(concentration=torch.tensor(dim / k,
                                                          device=device),
                               rate=1)
Example #17
    def _resample_std(self):
        # Non-negative constraints
        W_shape = F.softplus(self.W_shape)
        W_rate = F.softplus(self.W_rate)
        b_shape = F.softplus(self.b_shape)
        b_rate = F.softplus(self.b_rate)

        # Resample variances
        W_gamma_dist = dist.Gamma(W_shape, W_rate)
        b_gamma_dist = dist.Gamma(b_shape, b_rate)

        inv_W_var = W_gamma_dist.rsample()
        inv_b_var = b_gamma_dist.rsample()

        W_std = 1. / (torch.sqrt(inv_W_var) + self.eps)
        b_std = 1. / (torch.sqrt(inv_b_var) + self.eps)

        return W_std, b_std
Example #18
 def sample(self, size):
     if type(size) == int:
         sample_size = (size, self.dim)
     else:
         sample_size = list(size) + [self.dim]
     signs = 2.0 * distributions.Bernoulli(probs=0.75).sample(sample_size) - 1.0
     gammas = distributions.Gamma(concentration=self.shape, rate=self.rate).sample(
         sample_size
     )
     return signs * gammas
Example #19
def simulate_XZI_seq(Params, device=device, noise_X=1.):
    t = torch.tensor([0.]).to(device)
    X_s = [transform_x(t)]
    Z_s = [sample_using_logits(Params.Wx * X_s[0])]
    I_s = [dist.Gamma(*link_gamma(X_s[0], Z_s[0], Params)).sample()]

    for n in range(1, Params.N):
        t += I_s[-1]
        X_s.append(
            transform_x(t) + torch.randn((1, ), device=device) * Params.noise_X
        )  #uses current t which is a function of all I_{t'}, t' < t
        Z_s.append(sample_using_logits(Params.Wx * X_s[-1] +
                                       Params.P[Z_s[-1]]))  #X_{t} and Z_{t-1}
        I_s.append(dist.Gamma(
            *link_gamma(X_s[-1], Z_s[-1], Params)).sample())  #X_{t} and Z_{t}

    X_s = torch.tensor(X_s).to(device)
    Z_s = torch.tensor(Z_s).to(device)
    I_s = torch.tensor(I_s).to(device)
    return X_s, Z_s, I_s
Example #20
 def __init__(self, dim, sigma=None, lambd=None, p=1, device='cpu'):
     self.p = p
     if p <= 0:
         raise ValueError(f'p must be positive: {p} received')
     super().__init__(dim, sigma, lambd, device)
     if p <= 1:
         self.l1_radii = self.l1_rho = self._l1_table_info = None
     self.gamma_dist = D.Gamma(concentration=torch.tensor(1 / p,
                                                          device=device),
                               rate=1)
Example #21
    def _sample_std(self, shape, rate):
        with torch.no_grad():
            shape = F.softplus(shape)
            rate = F.softplus(rate)

            gamma_dist = dist.Gamma(shape, rate)
            inv_var = gamma_dist.rsample()
            std = 1. / (torch.sqrt(inv_var) + self.eps)

            return F.softplus(std)
Example #22
 def rsample(self, sample_shape=torch.Size()):
     # X_k | f ~ Gamma(alpha_{f,k}, 1)
     X = td.Gamma(self._concentration, torch.ones_like(self._concentration), validate_args=False)  
     # [batch_size, K]
     x = X.rsample(sample_shape)
     # now we mask the Gamma samples from invalid coordinates of lower-dimensional faces
     x = torch.where(self._mask, x, torch.zeros_like(x))  
     # finally, we renormalise the gamma samples
     z = x / x.sum(-1, keepdim=True)
     return z
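Without the mask, normalising K independent Gamma(α_k, 1) draws yields exactly a Dirichlet(α_1, ..., α_K) sample, which is what this rsample builds before zeroing the invalid coordinates. A small check of my own, with arbitrary concentrations and the mask assumed all True:

import torch
import torch.distributions as td

concentration = torch.tensor([0.5, 1.0, 2.0])  # arbitrary α_k
x = td.Gamma(concentration, torch.ones_like(concentration)).rsample((100_000,))
z = x / x.sum(-1, keepdim=True)

print(z.mean(0))                            # empirical mean on the simplex
print(concentration / concentration.sum())  # Dirichlet mean α_k / Σ_j α_j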
Example #23
    def __init__(self, multiplicity, in_features, dropout=0.0):
        """Creat a gamma layer.

        Args:
            multiplicity: Number of parallel representations for each input feature.
            in_features: Number of input features.

        """
        super().__init__(multiplicity, in_features, dropout)
        self.concentration = nn.Parameter(torch.rand(1, in_features, multiplicity))
        self.rate = nn.Parameter(torch.rand(1, in_features, multiplicity))
        self.gamma = dist.Gamma(concentration=self.concentration, rate=self.rate)
Example #24
 def posterior_predictive_std(self,
                              x: torch.Tensor,
                              exact: bool = True) -> torch.Tensor:
     if exact:
         mean_precision = self.α(x) / self.β(x)
         σ = 1 / torch.sqrt(mean_precision + self.eps)
     else:
         qp = D.Gamma(self.α(x), self.β(x))
         samples_precision = qp.rsample(torch.Size([self.n_mc_samples]))
         precision = torch.mean(samples_precision, 0, True)
         σ = 1 / torch.sqrt(precision)
     return σ
Example #25
    def __init__(self, in_features: int, out_channels: int, num_repetitions: int = 1, dropout=0.0):
        """Creat a gamma layer.

        Args:
            out_channels: Number of parallel representations for each input feature.
            in_features: Number of input features.

        """
        super().__init__(in_features, out_channels, num_repetitions, dropout)
        self.concentration = nn.Parameter(torch.rand(1, in_features, out_channels, num_repetitions))
        self.rate = nn.Parameter(torch.rand(1, in_features, out_channels, num_repetitions))
        self.gamma = dist.Gamma(concentration=self.concentration, rate=self.rate)
Example #26
    def sample_params(self, n_sample=torch.Size([])):
        clusters = self.cluster_distr.rsample(n_sample)
        params = self.cluster_to_params_graph(clusters)

        alpha_hsl0 = F.softplus(params[0:3])
        beta_hsl0 = F.softplus(params[3:6])
        hsl0 = td.Beta(alpha_hsl0, beta_hsl0).rsample()

        alpha_hsl1 = F.softplus(params[6:9])
        beta_hsl1 = F.softplus(params[9:12])
        hsl1 = td.Beta(alpha_hsl1, beta_hsl1).rsample()

        shape_trans01 = F.softplus(params[12:15])
        scale_trans01 = F.softplus(params[15:18])
        trans01 = td.Gamma(shape_trans01, scale_trans01).rsample()

        shape_trans10 = F.softplus(params[18:21])
        scale_trans10 = F.softplus(params[21:24])
        trans10 = td.Gamma(shape_trans10, scale_trans10).rsample()

        return hsl0, hsl1, trans01, trans10
Example #27
 def _forward(self, x_n, z_n):
     """
     x_n - shape=(BS,N)
     z_n - shape=(BS,N)
     """
     x_n = x_n.view(*x_n.shape, 1)  #shape = (BS,N,1)
     z_n_one_hot = self._make_one_hot(z_n)  #shape= (BS,N,B)
     glm_input = torch.cat([x_n, z_n_one_hot], dim=-1)
     gamma_params = self.glm(glm_input).mul(-1).exp()  #shape= (BS,N,2)
     num_dims = len(gamma_params.shape)
     # from docs: tensor.select(0, index) is equivalent to tensor[index] and tensor.select(2, index) is equivalent to tensor[:,:,index].
     a = gamma_params.select(num_dims - 1, 0)  # gamma_params[..., 0]
     b = gamma_params.select(num_dims - 1, 1)  # gamma_params[..., 1]
     dist_i = dist.Gamma(a, b)
     return dist_i
Example #28
 def __init__(self, dim, sigma=None, lambd=None, k=1, j=0, device='cpu'):
     self.k = k
     self.j = j
     super().__init__(dim, sigma, lambd, device)
     if dim > 1:
         self.gamma_factor = dim / (dim - 1) * math.exp(
             math.lgamma((dim - j) / k) - math.lgamma((dim - j - 1) / k))
     elif j == 0:
         self.gamma_factor = math.exp(
             math.lgamma((dim + k) / k) - math.lgamma((dim + k - 1) / k))
     else:
         raise ValueError(
             f'ExpInf(dim={dim}, k={k}, j={j}) is not a distribution.')
     self.gamma_dist = D.Gamma(concentration=torch.tensor((dim - j) / k,
                                                          device=device),
                               rate=1)
Example #29
    def decoder(self, z):
        x_mu = self.dec_mu(z)
        if self.switch:
            d = dist(z, self.C)
            d_min = d.min(dim=1, keepdim=True)[0]
            s = translatedSigmoid(d_min, -6.907 * 0.3, 0.3)
            alpha = self.alpha(z)
            beta = self.beta(z)
            gamma_dist = D.Gamma(alpha + 1e-6, beta + 1e-6)
            samples_var = gamma_dist.rsample([20])
            x_var = (1.0 / (samples_var + 1e-6))
            x_var = (1 - s) * x_var + s * (self.fixed_var *
                                           torch.ones_like(x_var))
        else:
            x_var = (0.02**2) * torch.ones_like(x_mu)

        return x_mu, x_var
Example #30
        def forward(self, x, switch):
            d = dist(x, c)
            d_min = d.min(dim=1, keepdim=True)[0]
            s = self.trans(d_min)
            mean = self.mean(x)
            if switch:
                a = self.alph(x)
                b = self.bet(x)
                gamma_dist = D.Gamma(a+1e-8, 1.0/(b+1e-8))
                if self.training:
                    samples_var = gamma_dist.rsample(torch.Size([num_draws_train]))
                    x_var = (1.0/(samples_var+1e-8))
                else:
                    samples_var = gamma_dist.rsample(torch.Size([2000]))
                    x_var = (1.0/(samples_var+1e-8))
                var = (1-s) * x_var + s * y_std ** 2

            else:
                var = 0.05*torch.ones_like(mean)
            return mean, var