Example no. 1
    def sample_from_beta_distribution(
        self,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        eps_gamma: float = 1e-30,
        eps_sample: float = 1e-7,
    ) -> torch.Tensor:
        # Sample from a Beta distribution using the reparameterization trick.
        # Problem: reparameterized Beta sampling is not implemented in CUDA yet.
        # Workaround: sample X and Y from Gamma(alpha, 1) and Gamma(beta, 1); the Beta sample is X / (X + Y).
        # Warning: work in log space and use logsumexp to avoid numerical issues.

        # Sample from Gamma
        sample_x_log = torch.log(Gamma(alpha, 1).rsample() + eps_gamma)
        sample_y_log = torch.log(Gamma(beta, 1).rsample() + eps_gamma)

        # Sum using logsumexp (note: eps_gamma is used to prevent numerical issues with
        # Beta samples that end up exactly 0 or 1)
        sample_xy_log_max = torch.max(sample_x_log, sample_y_log)
        sample_xplusy_log = sample_xy_log_max + torch.log(
            torch.exp(sample_x_log - sample_xy_log_max) +
            torch.exp(sample_y_log - sample_xy_log_max))
        sample_log = sample_x_log - sample_xplusy_log
        sample = eps_sample + (1 - 2 * eps_sample) * torch.exp(sample_log)

        return sample
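The comment above mentions logsumexp but the reduction is written out by hand; below is a minimal sketch of the same workaround using torch.logsumexp directly (the helper name beta_rsample_via_gamma is ours, not part of the snippet):

import torch
from torch.distributions import Gamma

def beta_rsample_via_gamma(alpha, beta, eps_gamma=1e-30, eps_sample=1e-7):
    # X ~ Gamma(alpha, 1), Y ~ Gamma(beta, 1)  =>  X / (X + Y) ~ Beta(alpha, beta)
    log_x = torch.log(Gamma(alpha, torch.ones_like(alpha)).rsample() + eps_gamma)
    log_y = torch.log(Gamma(beta, torch.ones_like(beta)).rsample() + eps_gamma)
    log_sample = log_x - torch.logsumexp(torch.stack([log_x, log_y]), dim=0)
    # keep samples strictly inside (0, 1) so downstream log-probs stay finite
    return eps_sample + (1 - 2 * eps_sample) * torch.exp(log_sample)

sample = beta_rsample_via_gamma(torch.tensor([2.0, 0.5]), torch.tensor([3.0, 0.5]))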
Example no. 2
    def gen_default_priors(self, data, K, L,
                           # sig_prior=Gamma(10, 10),
                           sig_prior=LogNormal(0, 0.01),
                           alpha_prior=Gamma(1, 1),
                           mu0_prior=None,
                           mu1_prior=None,
                           W_prior=None,
                           eta0_prior=None,
                           eta1_prior=None):

        if L is None:
            L = [5, 5]

        self.__cache_model_constants__(data, K, L)

        if mu0_prior is None:
            mu0_prior = Gamma(1, 1)

        if mu1_prior is None:
            mu1_prior = Gamma(1, 1)

        if W_prior is None:
            W_prior = Dirichlet(torch.ones(self.K) / self.K)

        if eta0_prior is None:
            eta0_prior = Dirichlet(torch.ones(self.L[0]) / self.L[0])

        if eta1_prior is None:
            eta1_prior = Dirichlet(torch.ones(self.L[1]) / self.L[1])

        self.priors = {'mu0': mu0_prior, 'mu1': mu1_prior, 'sig': sig_prior,
                       'eta0': eta0_prior, 'eta1': eta1_prior, 'W': W_prior,
                       'alpha': alpha_prior}
Example no. 3
 def forward(self, x, b):
     logN = torch.log(x.sum(axis=-1)).view(-1, 1)
     varz = torch.stack([self.variational_logvars] * len(logN))
     varz = torch.cat((varz, logN), dim=1)
     z_mean = self.encode(x, b)
     z_var = self.sigma_net(varz)
     gam = F.softplus(self.loggamma(b), beta=0.1)
     phi = F.softplus(self.logphi(b), beta=0.1)
     qz = Normal(z_mean, torch.exp(0.5 * z_var))
     ql = Normal(0, torch.exp(0.5 * self.log_sigma_sq))
     qb = Normal(self.beta(b), torch.exp(0.5 * self.beta_logvars(b)))
     qS = Gamma(gam, phi)
     # draw differentiable MC samples
     z_sample = qz.rsample()
     b_sample = rnormalgamma(qb, qS)
     l_sample = ql.rsample()
     # compute KL divergence + reconstruction loss
     x_out = self.decoder(z_sample) + b_sample + l_sample
     kl_div_z = kl_divergence(qz, Normal(0, 1)).mean(0).sum()
     kl_div_b = kl_divergence(qb, Normal(0, self.bpr)).mean(0).sum()
     kl_div_S = kl_divergence(qS, Gamma(self.gpr, self.ppr)).mean(0).sum()
     recon_loss = self.recon_model_loglik(x, x_out).mean(0).sum()
     elbo = recon_loss - kl_div_z - kl_div_b - kl_div_S
     loss = - elbo
     return loss, -recon_loss, kl_div_z, kl_div_b, kl_div_S
Example no. 4
def log_prior(Z, variable_types):
    if 'pi_unconstrained' in variable_types.keys():
        pi = softmax(Z['pi_unconstrained'][0], dim=-1)
        logp = Dirichlet(
            torch.ones_like(pi)).log_prob(pi) + log_det_jacobian_softmax(
                pi, dim=-1)
    else:
        logp = 0
    for i, (key, z) in enumerate(Z.items()):
        z = z[0]
        if key != 'pi_unconstrained':
            if variable_types[key] == 'Categorical':
                alpha = softmax(z, dim=-1, additional=-50.)
                logp += torch.sum(Dirichlet(torch.ones_like(alpha)).log_prob(alpha) \
                  + log_det_jacobian_softmax(alpha, dim=-1), dim=-1)
            #elif variable_types[key] == 'Bernoulli':
            #	theta = torch.sigmoid(z)
            #	logp += torch.sum(Beta(torch.ones_like(theta), torch.ones_like(theta)).log_prob(theta)\
            #			+ log_det_jacobian_sigmoid(theta), dim=-1)
            elif variable_types[key] == 'Bernoulli':
                logp += TransBeta.log_prob(z).sum()
            elif variable_types[key] == 'Beta':
                alpha, beta = torch.exp(z)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(alpha) +
                                  torch.log(alpha),
                                  dim=-1)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(beta) +
                                  torch.log(beta),
                                  dim=-1)
    return torch.mean(logp)
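In the 'Beta' branch above, Gamma(1.0, 1.0).log_prob(alpha) + torch.log(alpha) is a change of variables: the parameters are stored on the log scale, so the log-Jacobian of exp enters the prior. A small illustrative check, assuming only torch.distributions (not part of the original function):

import torch
from torch.distributions import Gamma, TransformedDistribution
from torch.distributions.transforms import ExpTransform

z = torch.randn(5)
alpha = torch.exp(z)
manual = Gamma(1.0, 1.0).log_prob(alpha) + torch.log(alpha)
# distribution of log(alpha) when alpha ~ Gamma(1, 1)
log_gamma = TransformedDistribution(Gamma(1.0, 1.0), ExpTransform().inv)
assert torch.allclose(manual, log_gamma.log_prob(z))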
Example no. 5
    def KLD(self, a, b, prior_alpha, prior_beta):
        eps = 5 * torch.finfo(torch.float).eps 
        a = a.clamp(eps)
        b = b.clamp(eps)
        if self.dist == "km":
            ab = (a * b) + eps
            # truncated series (10 terms) used in the Kumaraswamy-Beta KL
            kl = torch.zeros_like(ab)
            for m in range(1, 11):
                kl = kl + 1 / (m + ab) * self.Beta(m / a, b)
            kl *= (prior_beta - 1) * b

            kl += (a - prior_alpha) / a * (-np.euler_gamma - torch.digamma(b) - 1 / b)

            # add normalization constants
            kl += torch.log(ab) + torch.log(self.Beta(prior_alpha, prior_beta))

            # final term
            kl += -(b - 1) / b
        elif self.dist == "gamma":
            kl = torch.distributions.kl.kl_divergence(Gamma(a, b), Gamma(prior_alpha, prior_beta))
        elif self.dist == "gl":
            prior_alpha_beta = prior_alpha/(prior_alpha + prior_beta)
            prior_beta_beta =  torch.sqrt(prior_alpha*prior_beta / ((prior_alpha + prior_beta)**2*(prior_alpha + prior_beta + 1)))
            kl = torch.distributions.kl.kl_divergence(Independent(Normal(a, b),1), Independent(Normal(prior_alpha_beta, prior_beta_beta),1)).unsqueeze(1)
        return kl
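self.Beta above presumably evaluates the Beta function B(a, b); it is not shown in the snippet. A minimal sketch of such a helper, given as an assumption (the name beta_function is ours):

import torch

def beta_function(a, b):
    # B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b), evaluated in log space for stability
    return torch.exp(torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b))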
Example no. 6
    def test_gamma_prior_batch_log_prob(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")

        concentration = torch.tensor([1.0, 2.0], device=device)
        rate = torch.tensor([1.0, 2.0], device=device)
        prior = GammaPrior(concentration, rate)
        dist = Gamma(concentration, rate)
        t = torch.ones(2, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        t = torch.ones(2, 2, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        with self.assertRaises(RuntimeError):
            prior.log_prob(torch.ones(3, device=device))

        mean = torch.tensor([[1.0, 2.0], [0.5, 3.0]], device=device)
        variance = torch.tensor([[1.0, 2.0], [0.5, 1.0]], device=device)
        prior = GammaPrior(mean, variance)
        dist = Gamma(mean, variance)
        t = torch.ones(2, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        t = torch.ones(2, 2, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        with self.assertRaises(RuntimeError):
            prior.log_prob(torch.ones(3, device=device))
        with self.assertRaises(RuntimeError):
            prior.log_prob(torch.ones(2, 3, device=device))
Example no. 7
def log_prior(Z, variable_types):
    """
	Z : A dictionary containing the draws from variational posterior for each parameter.
	variable_types : a dictionary that contains distribution name assigned to each parameter.
	"""
    ## We proceed similarly as in log-likelihood computation, however since the elements of
    ## Z are in expanded form and the prior is not data dependent we compute the contribution
    ## of only the first element of element of Z.
    pi = softmax(Z['pi_unconstrained'][0], dim=-1)
    logp = Dirichlet(
        torch.ones_like(pi)).log_prob(pi) + log_det_jacobian_softmax(pi,
                                                                     dim=-1)
    for i, (key, z) in enumerate(Z.items()):
        if key != 'pi_unconstrained':
            z = z[0]
            if variable_types[key] == 'Categorical':
                alpha = softmax(z, dim=-1, additional=-50.)
                logp += torch.sum(Dirichlet(torch.ones_like(alpha)).log_prob(alpha) \
                  + log_det_jacobian_softmax(alpha, dim=-1), dim=-1)
            elif variable_types[key] == 'Bernoulli':
                theta = torch.sigmoid(z)
                logp += torch.sum(Beta(torch.ones_like(theta), torch.ones_like(theta)).log_prob(theta)\
                  + log_det_jacobian_sigmoid(theta), dim=-1)
            elif variable_types[key] == 'Beta':
                alpha, beta = torch.exp(z)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(alpha) +
                                  torch.log(alpha),
                                  dim=-1)
                logp += torch.sum(Gamma(1.0, 1.0).log_prob(beta) +
                                  torch.log(beta),
                                  dim=-1)
    return logp
Example no. 8
    def _forward(self, shape, rate):
        shape_like = self.log_shape_like.exp()
        rate_like = self.log_rate_like.exp()

        post_shape = shape + shape_like
        post_rate  = rate  + rate_like

        return Gamma(shape, rate), Gamma(post_shape, post_rate)
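The additive shape/rate update above has the same form as the standard Gamma conjugate update. For reference, a hedged sketch of the textbook Gamma-Poisson case (illustrative only, not this module's code):

import torch
from torch.distributions import Gamma, Poisson

prior_shape, prior_rate = torch.tensor(2.0), torch.tensor(1.0)
true_rate = torch.tensor(4.0)
x = Poisson(true_rate).sample((100,))
# with a Gamma(shape, rate) prior on the Poisson rate, the posterior is
# Gamma(shape + sum(x), rate + n)
posterior = Gamma(prior_shape + x.sum(), prior_rate + x.numel())
print(posterior.mean)  # concentrates near true_rate as n grows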
Example no. 9
 def get_kl(self):
     gamma_q = Gamma(concentration=self.logalpha.exp(), rate=self.logbeta.exp())
     gamma_p = Gamma(0.1*torch.ones_like(self.logalpha), 0.3*torch.ones_like(self.logalpha))
     beta_q = Beta(self.logtheta.exp(), self.logeta.exp())
     beta_p = Beta(torch.ones_like(self.logtheta), torch.ones_like(self.logtheta))
     # kl = _kl_beta_beta(beta_q, beta_p) + _kl_gamma_gamma(gamma_q, gamma_p)
     kl = kl_divergence(beta_q, beta_p).sum() + kl_divergence(gamma_q, gamma_p).sum()
     return kl
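A quick illustrative check, separate from the module above, that the analytic kl_divergence(Gamma, Gamma) agrees with a Monte Carlo estimate of E_q[log q - log p]:

import torch
from torch.distributions import Gamma, kl_divergence

q = Gamma(torch.tensor(2.0), torch.tensor(1.5))
p = Gamma(torch.tensor(0.1), torch.tensor(0.3))
x = q.sample((200000,))
mc_estimate = (q.log_prob(x) - p.log_prob(x)).mean()
print(mc_estimate.item(), kl_divergence(q, p).item())  # the two should be close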
Example no. 10
    def __init__(self,
                 event_shape: torch.Size,
                 df=3.,
                 loc=0.,
                 covariance_matrix=None,
                 precision_matrix=None,
                 scale_tril=None,
                 validate_args=None):
        super().__init__(loc=loc,
                         covariance_matrix=covariance_matrix,
                         precision_matrix=precision_matrix,
                         scale_tril=scale_tril,
                         validate_args=validate_args)

        # self._event_shape is inferred from the mean vector and covariance matrix.
        old_event_shape = self._event_shape
        if not len(event_shape) >= len(old_event_shape):
            raise NotImplementedError("non-elliptical MVT not in this class")
        assert len(event_shape) >= 1
        assert event_shape[-len(old_event_shape):] == old_event_shape

        # Cut dimensions from the end of `batch_shape` so the `total_shape` is
        # the same
        total_shape = list(self._batch_shape) + list(self._event_shape)
        self._batch_shape = torch.Size(total_shape[:-len(event_shape)])
        self._event_shape = torch.Size(event_shape)

        self.df, _ = broadcast_all(df, torch.ones(self._batch_shape))
        self.gamma = Gamma(concentration=self.df / 2., rate=1 / 2)
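A hypothetical sketch, separate from the class above, of how a Gamma(df / 2, 1 / 2) mixing variable (a chi-squared with df degrees of freedom) is typically used to draw multivariate Student-t samples:

import torch
from torch.distributions import Gamma, MultivariateNormal

df, dim = 5.0, 3
mvn = MultivariateNormal(torch.zeros(dim), torch.eye(dim))
chi2 = Gamma(torch.tensor(df / 2.0), torch.tensor(0.5))

def sample_mvt(loc=torch.zeros(dim)):
    # loc + z * sqrt(df / w), with w ~ chi-squared(df), is multivariate Student-t
    return loc + mvn.sample() * torch.sqrt(df / chi2.sample())

samples = torch.stack([sample_mvt() for _ in range(4)])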
Example no. 11
def draw_samples(path, name, model, test_loader, train_loader, method):
    sample = None
    if path.startswith('./assets/data/sbvae'):
        if method == 'km':
            sample = Beta(torch.tensor([1.0]),
                          torch.tensor([5.0])).rsample([64, 50
                                                        ]).squeeze().to(device)
        elif method == 'gamma':
            sample = Gamma(torch.tensor([1.0]),
                           torch.tensor([5.0
                                         ])).rsample([64,
                                                      50]).squeeze().to(device)
        cumprods = torch.cat((torch.ones(
            [64, 1], device=device), torch.cumprod(1 - sample, axis=1)),
                             dim=1)
        sample = cumprods[:, :-1] * sample
        sample[:, -1] = 1 - sample[:, :-1].sum(axis=1)
    else:
        sample = torch.randn(64, 50)
    sample = torch.sigmoid(model.decode(sample)).reshape(
        64, 28, 28).cpu().detach().numpy()
    f, axarr = plt.subplots(8, 8)
    for i in range(64):
        axarr[i // 8, i % 8].imshow(sample[i])
    plt.savefig(path + ".samples.pdf")
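The cumprod block above is a stick-breaking construction; a minimal, standalone sketch (not the plotting code) showing that the resulting weights sum to one per row:

import torch
from torch.distributions import Beta

v = Beta(torch.tensor([1.0]), torch.tensor([5.0])).sample([8, 50]).squeeze(-1)
remaining = torch.cat([torch.ones(8, 1), torch.cumprod(1 - v, dim=1)], dim=1)
w = remaining[:, :-1] * v            # w_k = v_k * prod_{j<k} (1 - v_j)
w[:, -1] = 1 - w[:, :-1].sum(dim=1)  # last weight takes the remaining stick
assert torch.allclose(w.sum(dim=1), torch.ones(8))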
Example no. 12
 def test_gamma_sample_grad(self):
     self._set_rng_seed(1)
     num_samples = 100
     for alpha in [1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
         alphas = Variable(torch.Tensor([alpha] * num_samples),
                           requires_grad=True)
         betas = Variable(torch.ones(num_samples))
         x = Gamma(alphas, betas).sample()
         x.sum().backward()
         x, ind = x.data.sort()
         x = x.numpy()
         actual_grad = alphas.grad.data[ind].numpy()
         # Compare with expected gradient dx/dalpha along constant cdf(x,alpha).
         cdf = scipy.stats.gamma.cdf
         pdf = scipy.stats.gamma.pdf
         eps = 0.02 * alpha if alpha < 100 else 0.02 * alpha**0.5
         cdf_alpha = (cdf(x, alpha + eps) - cdf(x, alpha - eps)) / (2 * eps)
         cdf_x = pdf(x, alpha)
         expected_grad = -cdf_alpha / cdf_x
         rel_error = np.abs(actual_grad - expected_grad) / (expected_grad +
                                                            1e-100)
         self.assertLess(
             np.max(rel_error), 0.005, '\n'.join([
                 'Bad gradients for Gamma({}, 1)'.format(alpha),
                 'x {}'.format(x), 'expected {}'.format(expected_grad),
                 'actual {}'.format(actual_grad),
                 'rel error {}'.format(rel_error),
                 'max error {}'.format(rel_error.max())
             ]))
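In current PyTorch the same pathwise derivative is exposed through rsample; a minimal sketch, separate from the test, of the gradient that the finite-difference check above validates:

import torch
from torch.distributions import Gamma

alpha = torch.tensor([2.0], requires_grad=True)
x = Gamma(alpha, torch.ones(1)).rsample()
x.sum().backward()
print(alpha.grad)  # dx/dalpha along the fixed underlying draw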
Example no. 13
 def __init__(self, concentration, rate, validate_args=None):
     base_dist = Gamma(concentration, rate)
     super().__init__(
         base_dist,
         PowerTransform(-base_dist.rate.new_ones(())),
         validate_args=validate_args,
     )
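This wraps Gamma with PowerTransform(-1), i.e. an inverse-Gamma: if X ~ Gamma(concentration, rate), the transformed variable is 1 / X. A hedged usage sketch built the same way (standalone, not the class itself):

import torch
from torch.distributions import Gamma, TransformedDistribution
from torch.distributions.transforms import PowerTransform

base = Gamma(torch.tensor(3.0), torch.tensor(2.0))
inv_gamma = TransformedDistribution(base, PowerTransform(-base.rate.new_ones(())))
y = inv_gamma.sample((5,))
# log_prob accounts for the Jacobian of x -> x**(-1) automatically
print(inv_gamma.log_prob(y))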
Example no. 14
 def get_log_prob(self, state, action):
     with torch.no_grad():
         concentration = self._get_concentration(state)
         # This is not getting exported; we can use it
         dist = Gamma(concentration, rate=torch.ones(concentration.shape))
         log_prob = dist.log_prob(action)
     return log_prob
Example no. 15
    def generate_denoised_samples(
        self,
        n_samples: int = 25,
        batch_size: int = 64,
        rna_size_factor: int = 1,
        transform_batch: Optional[int] = None,
    ):
        """ Return samples from an adjusted posterior predictive. Proteins are concatenated to genes.

        :param n_samples: How may samples per cell
        :param batch_size: Mini-batch size for sampling. Lower means less GPU memory footprint
        :rna_size_factor: size factor for RNA prior to sampling gamma distribution
        :transform_batch: int of which batch to condition on for all cells
        :return:
        """
        posterior_list = []
        for tensors in self.update({"batch_size": batch_size}):
            x, _, _, batch_index, labels, y = tensors
            with torch.no_grad():
                outputs = self.model.inference(
                    x,
                    y,
                    batch_index=batch_index,
                    label=labels,
                    n_samples=n_samples,
                    transform_batch=transform_batch,
                )
            px_ = outputs["px_"]
            py_ = outputs["py_"]

            pi = 1 / (1 + torch.exp(-py_["mixing"]))
            mixing_sample = Bernoulli(pi).sample()
            protein_rate = py_["rate_fore"]
            rate = torch.cat((rna_size_factor * px_["scale"], protein_rate),
                             dim=-1)
            if len(px_["r"].size()) == 2:
                px_dispersion = px_["r"]
            else:
                px_dispersion = torch.ones_like(x) * px_["r"]
            if len(py_["r"].size()) == 2:
                py_dispersion = py_["r"]
            else:
                py_dispersion = torch.ones_like(y) * py_["r"]

            dispersion = torch.cat((px_dispersion, py_dispersion), dim=-1)

            # This gamma is really l*w using scVI manuscript notation
            p = rate / (rate + dispersion)
            r = dispersion
            l_train = Gamma(r, (1 - p) / p).sample()
            data = l_train.cpu().numpy()
            # make background 0
            data[:, :, self.gene_dataset.nb_genes:] = (
                data[:, :, self.gene_dataset.nb_genes:] *
                (1 - mixing_sample).cpu().numpy())
            posterior_list += [data]

            posterior_list[-1] = np.transpose(posterior_list[-1], (1, 2, 0))

        return np.concatenate(posterior_list, axis=0)
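The Gamma(r, (1 - p) / p) parameterization above is the Gamma-Poisson (negative binomial) mixture form used in scVI-style models; a small illustrative check, not part of the posterior class, that the mixture matches torch's NegativeBinomial:

import torch
from torch.distributions import Gamma, NegativeBinomial, Poisson

r = torch.tensor(5.0)           # dispersion
mean_rate = torch.tensor(20.0)  # "rate" in the snippet's notation
p = mean_rate / (mean_rate + r)

l = Gamma(r, (1 - p) / p).sample((100000,))
counts = Poisson(l).sample()
nb = NegativeBinomial(total_count=r, probs=p)
print(counts.mean().item(), nb.mean.item())  # both close to mean_rate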
Example no. 16
def model(data, **kwargs):
    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))

    zeta = 2. * torch.ones(T * D, device=device)
    delta = 2. * torch.ones(T * D, device=device)
    with pyro.plate("prec_plate", T * D):
        prec = pyro.sample("prec", Gamma(zeta, delta))

    corr_chol = torch.zeros(T, D, D)
    for t in pyro.plate("corr_chol_plate", T):
        corr_chol[t, ...] = pyro.sample(
            "corr_chol_{}".format(t),
            LKJCorrCholesky(d=D, eta=torch.ones(1, device=device)))

    with pyro.plate("mu_plate", T):
        _std = torch.sqrt(1. / prec.view(-1, D))
        sigma_chol = torch.bmm(torch.diag_embed(_std), corr_chol)
        mu = pyro.sample(
            "mu",
            MultivariateNormal(torch.zeros(T, D, device=device),
                               scale_tril=sigma_chol))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs",
                    MultivariateNormal(mu[z], scale_tril=sigma_chol[z]),
                    obs=data)
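The model above calls mix_weights(beta), which is not shown; a common stick-breaking definition (an assumption here, in the style of Pyro's DP-mixture examples) would be:

import torch
import torch.nn.functional as F

def mix_weights(beta):
    # stick-breaking: w_k = beta_k * prod_{j<k} (1 - beta_j); the last stick takes the rest
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)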
Example no. 17
def test_gamma_likelihood(concentration: float, rate: float) -> None:
    """
    Test to check that maximizing the likelihood recovers the parameters
    """

    # generate samples
    concentrations = torch.zeros((NUM_SAMPLES,)) + concentration
    rates = torch.zeros((NUM_SAMPLES,)) + rate

    distr = Gamma(concentrations, rates)
    samples = distr.sample()

    init_biases = [
        inv_softplus(concentration - START_TOL_MULTIPLE * TOL * concentration),
        inv_softplus(rate - START_TOL_MULTIPLE * TOL * rate),
    ]

    concentration_hat, rate_hat = maximum_likelihood_estimate_sgd(
        GammaOutput(),
        samples,
        init_biases=init_biases,
        learning_rate=PositiveFloat(0.05),
        num_epochs=PositiveInt(10),
    )

    assert (
        np.abs(concentration_hat - concentration) < TOL * concentration
    ), f"concentration did not match: concentration = {concentration}, concentration_hat = {concentration_hat}"
    assert (
        np.abs(rate_hat - rate) < TOL * rate
    ), f"rate did not match: rate = {rate}, rate_hat = {rate_hat}"
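inv_softplus above is not shown; a minimal sketch of the usual inverse of softplus (an assumption, not the test suite's helper):

import numpy as np

def inv_softplus(y: float) -> float:
    # softplus(x) = log(1 + exp(x))  =>  x = log(exp(y) - 1), valid for y > 0
    return float(np.log(np.expm1(y)))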
Example no. 18
    def __init__(self, seed=None, gamma_k=2.0, gamma_loc=0.0, normal_mean=5.0, normal_sigma=0.5, cuda=False,
                     background_luminosity=1000, signal_luminosity=1000):
        self.seed = seed
        if cuda:
            self.cuda()
        else:
            self.cpu()
        config = GGConfig()
        self.rescale = self.tensor(config.CALIBRATED.rescale, requires_grad=True)
        self.mu = self.tensor(config.CALIBRATED.mu, requires_grad=True)
        self.nuisance_params = OrderedDict([
                                ('rescale', self.rescale),
                                ])
        # Define distributions
        self.gamma_k      = self.tensor(gamma_k)
        self.gamma_loc    = self.tensor(gamma_loc)
        self.gamma_rate   = self.tensor(1.0)

        self.normal_mean  = self.tensor(normal_mean)
        self.normal_sigma = self.tensor(normal_sigma)

        self.gamma = Gamma(self.gamma_k, self.gamma_rate)
        self.norm  = Normal(self.normal_mean, self.normal_sigma)

        self.background_luminosity = background_luminosity
        self.signal_luminosity = signal_luminosity
        self.n_expected_events = background_luminosity + signal_luminosity
Example no. 19
    def gen_default_priors(self,
                           K,
                           L,
                           sig_prior=LogNormal(0, 1),
                           alpha_prior=Gamma(1., 1.),
                           mu0_prior=None,
                           mu1_prior=None,
                           W_prior=None,
                           eta0_prior=None,
                           eta1_prior=None):

        if L is None:
            L = [5, 5]

        self.L = L

        if K is None:
            K = 30

        self.K = K

        # FIXME: these should be an ordered prior of TruncatedNormals
        if mu0_prior is None:
            mu0_prior = Gamma(1, 1)

        if mu1_prior is None:
            mu1_prior = Gamma(1, 1)

        if W_prior is None:
            W_prior = Dirichlet(torch.ones(self.K) / self.K)

        if eta0_prior is None:
            eta0_prior = Dirichlet(torch.ones(self.L[0]) / self.L[0])

        if eta1_prior is None:
            eta1_prior = Dirichlet(torch.ones(self.L[1]) / self.L[1])

        self.priors = {
            'mu0': mu0_prior,
            'mu1': mu1_prior,
            'sig': sig_prior,
            'H': Normal(0, 1),
            'eta0': eta0_prior,
            'eta1': eta1_prior,
            'W': W_prior,
            'alpha': alpha_prior
        }
Example no. 20
    def __init__(self,
                 dim,
                 act=nn.ReLU(),
                 num_hiddens=[50],
                 nout=1,
                 conf=dict()):
        nn.Module.__init__(self)
        BNN.__init__(self)
        self.dim = dim
        self.act = act
        self.num_hiddens = num_hiddens
        self.nout = nout
        self.steps_burnin = conf.get('steps_burnin', 2500)
        self.steps = conf.get('steps', 2500)
        self.keep_every = conf.get('keep_every', 50)
        self.batch_size = conf.get('batch_size', 32)
        self.warm_start = conf.get('warm_start', False)

        self.lr_weight = np.float32(conf.get('lr_weight', 1e-3))
        self.lr_noise = np.float32(conf.get('lr_noise', 1e-3))
        self.lr_lambda = np.float32(conf.get('lr_lambda', 1e-3))
        self.alpha_w = torch.as_tensor(1. * conf.get('alpha_w', 6.))
        self.beta_w = torch.as_tensor(1. * conf.get('beta_w', 6.))
        self.alpha_n = torch.as_tensor(1. * conf.get('alpha_n', 6.))
        self.beta_n = torch.as_tensor(1. * conf.get('beta_n', 6.))
        self.noise_level = conf.get('noise_level', None)
        if self.noise_level is not None:
            prec = 1 / self.noise_level**2
            prec_var = (prec * 0.25)**2
            self.beta_n = torch.as_tensor(prec / prec_var)
            self.alpha_n = torch.as_tensor(prec * self.beta_n)
            print("Reset alpha_n = %g, beta_n = %g" %
                  (self.alpha_n, self.beta_n))

        self.prior_log_lambda = TransformedDistribution(
            Gamma(self.alpha_w, self.beta_w),
            ExpTransform().inv)  # log of gamma distribution
        self.prior_log_precision = TransformedDistribution(
            Gamma(self.alpha_n, self.beta_n),
            ExpTransform().inv)

        self.log_lambda = nn.Parameter(torch.tensor(0.))
        self.log_precs = nn.Parameter(torch.zeros(self.nout))
        self.nn = NN(dim, self.act, self.num_hiddens, self.nout)

        self.init_nn()
Example no. 21
 def test_gamma_shape_tensor_params(self):
     gamma = Gamma(torch.Tensor([1, 1]), torch.Tensor([1, 1]))
     self.assertEqual(gamma._batch_shape, torch.Size((2,)))
     self.assertEqual(gamma._event_shape, torch.Size(()))
     self.assertEqual(gamma.sample().size(), torch.Size((2,)))
     self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2, 2)))
     self.assertEqual(gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
     self.assertRaises(ValueError, gamma.log_prob, self.tensor_sample_2)
Example no. 22
 def sample(self):
     z = Beta(concentration0=self.logtheta.exp().detach(),
              concentration1=self.logeta.exp().detach()).sample()
     logit_z = torch.log(z / (1 - z))
     w = Gamma(concentration=self.logalpha.exp().detach(),
               rate=self.logbeta.exp().detach()).sample()
     log_w = w.log()
     return logit_z, log_w
Example no. 23
    def generate(
        self,
        n_samples: int = 100,
        batch_size: int = 64
    ):  # with n_samples>1 return original list/ otherwise sequential
        """
        Return samples from posterior predictive. Proteins are concatenated to genes.

        :param n_samples: Number of posterior predictive samples
        :return: Tuple of posterior samples, original data
        """
        original_list = []
        posterior_list = []
        for tensors in self.update({"batch_size": batch_size}):
            x, _, _, batch_index, labels, y = tensors
            with torch.no_grad():
                outputs = self.model.inference(x,
                                               y,
                                               batch_index=batch_index,
                                               label=labels,
                                               n_samples=n_samples)
            px_ = outputs["px_"]
            py_ = outputs["py_"]

            pi = 1 / (1 + torch.exp(-py_["mixing"]))
            mixing_sample = Bernoulli(pi).sample()
            protein_rate = (py_["rate_fore"] * (1 - mixing_sample) +
                            py_["rate_back"] * mixing_sample)
            rate = torch.cat((px_["rate"], protein_rate), dim=-1)
            if len(px_["r"].size()) == 2:
                px_dispersion = px_["r"]
            else:
                px_dispersion = torch.ones_like(x) * px_["r"]
            if len(py_["r"].size()) == 2:
                py_dispersion = py_["r"]
            else:
                py_dispersion = torch.ones_like(y) * py_["r"]

            dispersion = torch.cat((px_dispersion, py_dispersion), dim=-1)

            # This gamma is really l*w using scVI manuscript notation
            p = rate / (rate + dispersion)
            r = dispersion
            l_train = Gamma(r, (1 - p) / p).sample()
            data = Poisson(l_train).sample().cpu().numpy()
            # """
            # In numpy (shape, scale) => (concentration, rate), with scale = p /(1 - p)
            # rate = (1 - p) / p  # = 1/scale # used in pytorch
            # """
            original_list += [np.array(torch.cat((x, y), dim=-1).cpu())]
            posterior_list += [data]

            posterior_list[-1] = np.transpose(posterior_list[-1], (1, 2, 0))

        return (
            np.concatenate(posterior_list, axis=0),
            np.concatenate(original_list, axis=0),
        )
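An illustrative check of the parameterization note above (not scVI code): NumPy's gamma sampler takes (shape, scale) with scale = p / (1 - p), while PyTorch's Gamma takes (concentration, rate) with rate = (1 - p) / p, so the two agree in distribution:

import numpy as np
import torch
from torch.distributions import Gamma

r, p = 4.0, 0.3
torch_draws = Gamma(torch.tensor(r), torch.tensor((1 - p) / p)).sample((200000,))
numpy_draws = np.random.gamma(shape=r, scale=p / (1 - p), size=200000)
print(torch_draws.mean().item(), numpy_draws.mean())  # both near r * p / (1 - p)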
Example no. 24
 def sample(self, m, a):
     r = 1 / a
     p = (m * a) / (1 + (m * a))
     b = (1 - p) / p
     g = Gamma(r, b)
     g = g.sample()
     p = Poisson(g)
     z = p.sample()
     return z
Example no. 25
 def test_gamma_shape_scalar_params(self):
     gamma = Gamma(1, 1)
     self.assertEqual(gamma._batch_shape, torch.Size())
     self.assertEqual(gamma._event_shape, torch.Size())
     self.assertEqual(gamma.sample().size(), torch.Size((1,)))
     self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2)))
     self.assertRaises(ValueError, gamma.log_prob, self.scalar_sample)
     self.assertEqual(gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
     self.assertEqual(gamma.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
Example no. 26
 def backward(ctx, output_grad):
     w, alpha, beta = ctx.saved_tensors
     log_w = w.log()
      log_w = log_w + output_grad * w  # do the update on the log scale and transform it back
      updated_w = log_w.exp()
      updated_w = updated_w.detach()
      ll = Gamma(concentration=alpha, rate=beta).log_prob(updated_w)
      ll.sum().backward()
     return alpha.grad, beta.grad
Example no. 27
    def kl(self, phi):
        qz = self.z_mean.expand_as(self.theta)
        qlogp = torch.log(nlp_pdf(self.signed_theta, phi, tau=0.358) + 1e-4)
        gamma = Gamma(concentration=self.detached_gamma_alpha, rate=1.)
        qlogq = np.log(0.5) + gamma.log_prob(
            self.theta)  # use unsigned self.theta to compute qlogq
        kl_beta = qlogq - qlogp
        kl_z = qz * torch.log(qz / 0.1) + (1 - qz) * torch.log((1 - qz) / 0.9)
        kl = (kl_z + qz * kl_beta).sum(dim=1).mean()

        return kl
Example no. 28
 def __init__(self, data, mu, df_prior, rate_prior):
     """
     :param data: Torch Tensor
     :param mu: float
     :param df_prior: float
     :param rate_prior: float
     """
     self.prior = Gamma(torch.tensor([df_prior], dtype=torch.float64),
                        torch.tensor([rate_prior], dtype=torch.float64))
     self.mu = mu
     self.data = data
Example no. 29
    def log_like(self, Y, dT):

        prob = []
        for k in range(self.K):
            p_g = Gamma(torch.exp(self.log_ab[k][0]),
                        torch.exp(self.log_ab[k][1]))
            prob_g = p_g.log_prob(dT)
            prob.append(prob_g)
        prob = torch.stack(prob).t() + self.likelihood.mixture_prob(Y)

        return prob
Example no. 30
def main():
    # TODO 1: add plot output
    epoch = 4001
    sampler = MarsagliaTsampler(size=1)
    optimizer = optim.SGD(sampler.parameters(), lr=0.001, momentum=0.9)

    alpha, beta = 10, 1
    target_distribution = Gamma(concentration=alpha, rate=beta)

    filenames = []
    path = '/extra/yadongl10/git_project/GammaLearningResult'
    for i in range(epoch):
        # compute loss
        samples = sampler(batch_size=128)
        samples = samples[samples > 0]

        loss = CE_Gamma(
            samples, target_distribution
        )  # For MarsagliaTsampler, currently only supports beta=1

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print intermediate results
        if i % 500 == 0:
            print('loss {}'.format(loss))
            plt.plot(np.linspace(0.001, 25, 1000),
                     target_distribution.log_prob(
                         torch.linspace(0.001, 25, 1000)).exp().tolist(),
                     label='Target: Gamma({},{})'.format(alpha, beta))
            test_samples = sampler(batch_size=1000)
            test_samples = test_samples[test_samples > 0]
            plt.hist(test_samples.tolist(),
                     bins=30,
                     density=True,
                     histtype='step')
            plt.xlim(0, 30)
            plt.ylim(0, 0.25)
            plt.legend()
            plt.title(
                'Histogram of Samples Generated from trained models at epoch: {}'
                .format(i))
            filenames.append('{}/epoch{}.png'.format(path, i))
            plt.savefig('{}/epoch{}.png'.format(path, i), dpi=200)
            plt.close()

    # make gif
    images = []
    for filename in filenames:
        images.append(imageio.imread(filename))
    imageio.mimwrite('{}/movie.gif'.format(path), images, duration=1)