Example #1
 def test_gamma_shape_tensor_params(self):
     gamma = Gamma(torch.Tensor([1, 1]), torch.Tensor([1, 1]))
     self.assertEqual(gamma._batch_shape, torch.Size((2,)))
     self.assertEqual(gamma._event_shape, torch.Size(()))
     self.assertEqual(gamma.sample().size(), torch.Size((2,)))
     self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2, 2)))
     self.assertEqual(gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
     self.assertRaises(ValueError, gamma.log_prob, self.tensor_sample_2)
Example #2
 def backward(ctx, output_grad):
     w, alpha, beta = ctx.saved_tensors
     log_w = w.log()
     # take the update step on the log scale (unit step size), then map back
     log_w = log_w + output_grad * w
     updated_w = log_w.exp()
     updated_w = updated_w.detach()
     ll = Gamma(concentration=alpha, rate=beta).log_prob(updated_w)
     ll.sum().backward()  # populates alpha.grad / beta.grad (both must be leaf tensors)
     return alpha.grad, beta.grad
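A minimal, self-contained sketch of how such a backward pass could be wrapped in a torch.autograd.Function, assuming a forward pass that simply passes `w` through and saves the Gamma parameters; the class name and the gradient routing are illustrative, not the original module.

import torch
from torch.distributions import Gamma

class GammaScoredUpdate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, alpha, beta):
        # w is assumed positive so that w.log() is defined in backward()
        ctx.save_for_backward(w, alpha, beta)
        return w

    @staticmethod
    def backward(ctx, output_grad):
        w, alpha, beta = ctx.saved_tensors
        # nudge w on the log scale, map back, and stop gradients through it
        updated_w = (w.log() + output_grad * w).exp().detach()
        with torch.enable_grad():
            alpha = alpha.detach().requires_grad_(True)
            beta = beta.detach().requires_grad_(True)
            ll = Gamma(concentration=alpha, rate=beta).log_prob(updated_w).sum()
            grad_alpha, grad_beta = torch.autograd.grad(ll, (alpha, beta))
        # pass the incoming gradient through for w; score-based grads for alpha, beta
        return output_grad, grad_alpha, grad_beta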
Example #3
 def sample(self, m, a):
     # Negative Binomial with mean m and dispersion a, drawn as a Gamma-Poisson mixture
     r = 1 / a                    # Gamma concentration
     p = (m * a) / (1 + (m * a))  # success probability
     b = (1 - p) / p              # Gamma rate
     g = Gamma(r, b).sample()
     z = Poisson(g).sample()
     return z
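The snippet above uses the Gamma-Poisson mixture representation of the Negative Binomial: with these parameters the Gamma mean is m, so z has mean m and variance m + a*m**2. A quick stand-alone Monte Carlo check with arbitrary values:

import torch
from torch.distributions import Gamma, Poisson

m, a = torch.tensor(4.0), torch.tensor(0.5)
r = 1 / a                      # concentration
p = (m * a) / (1 + m * a)      # success probability
b = (1 - p) / p                # rate, so the Gamma mean r / b equals m

g = Gamma(r, b).sample((100_000,))
z = Poisson(g).sample()
print(z.mean(), z.var())       # roughly m and m + a * m**2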
Example #4
 def test_gamma_shape_scalar_params(self):
     gamma = Gamma(1, 1)
     self.assertEqual(gamma._batch_shape, torch.Size())
     self.assertEqual(gamma._event_shape, torch.Size())
     self.assertEqual(gamma.sample().size(), torch.Size((1,)))
     self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2)))
     self.assertRaises(ValueError, gamma.log_prob, self.scalar_sample)
     self.assertEqual(gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
     self.assertEqual(gamma.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
Example #5
 def get_kl(self):
     gamma_q = Gamma(concentration=self.logalpha.exp(),
                     rate=self.logbeta.exp())
     gamma_p = Gamma(0.1 * torch.ones_like(self.logalpha),
                     0.3 * torch.ones_like(self.logalpha))
     beta_q = Beta(self.logtheta.exp(), self.logeta.exp())
     beta_p = Beta(torch.ones_like(self.logtheta),
                   torch.ones_like(self.logtheta))
     kl = kl_divergence(beta_q, beta_p).sum() + kl_divergence(
         gamma_q, gamma_p).sum()
     return kl
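kl_divergence here uses the closed-form Gamma-Gamma and Beta-Beta kernels registered in torch.distributions.kl, so the result stays differentiable with respect to the exponentiated log-parameters. A minimal stand-alone illustration (parameter values are arbitrary):

import torch
from torch.distributions import Beta, Gamma, kl_divergence

gamma_q = Gamma(torch.tensor(2.0), torch.tensor(1.5))
gamma_p = Gamma(torch.tensor(0.1), torch.tensor(0.3))
beta_q = Beta(torch.tensor(2.0), torch.tensor(3.0))
beta_p = Beta(torch.tensor(1.0), torch.tensor(1.0))

kl = kl_divergence(beta_q, beta_p) + kl_divergence(gamma_q, gamma_p)
print(kl)  # non-negative scalar tensor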
Example #6
    def kl(self, phi):
        qz = self.z_mean.expand_as(self.theta)
        qlogp = torch.log(nlp_pdf(self.signed_theta, phi, tau=0.358) + 1e-4)
        gamma = Gamma(concentration=self.detached_gamma_alpha, rate=1.)
        qlogq = np.log(0.5) + gamma.log_prob(self.theta)  # use unsigned self.theta to compute qlogq
        kl_beta = qlogq - qlogp
        kl_z = qz * torch.log(qz / 0.1) + (1 - qz) * torch.log((1 - qz) / 0.9)
        kl = (kl_z + qz * kl_beta).sum(dim=1).mean()

        return kl
Example #7
 def __init__(self, data, mu, df_prior, rate_prior):
     """
     :param data: Torch Tensor
     :param mu: float
     :param df_prior: float
     :param rate_prior: float
     """
     self.prior = Gamma(torch.tensor([df_prior], dtype=torch.float64),
                        torch.tensor([rate_prior], dtype=torch.float64))
     self.mu = mu
     self.data = data
Example #8
    def log_like(self, Y, dT):

        prob = []
        for k in range(self.K):
            p_g = Gamma(torch.exp(self.log_ab[k][0]),
                        torch.exp(self.log_ab[k][1]))
            prob_g = p_g.log_prob(dT)
            prob.append(prob_g)
        prob = torch.stack(prob).t() + self.likelihood.mixture_prob(Y)

        return prob
Example #9
 def fill_gamma(self, random=False, sequence=[]):
     distribution = Gamma(torch.tensor([4.0]), torch.tensor([0.5]))
     if random:
         for i in range(self.initial_fill):
             val = distribution.sample()
             val = (val / 20.0) * (self.kind_cars - 1)
             if val >= self.kind_cars:
                 val = torch.tensor(self.kind_cars) - 0.5
             car = int(torch.floor(val).item())
             #car = np.random.randint(0, self.kind_cars)
             line = np.random.choice(self.possible_actions())
             self.add_to_buffer(car, line)
Example #10
 def __init__(self,
              concentration,
              rate,
              validate_args=False,
              transform=None):
     TModule.__init__(self)
     Gamma.__init__(self,
                    concentration=concentration,
                    rate=rate,
                    validate_args=validate_args)
     _bufferize_attributes(self, ("concentration", "rate"))
     self._transform = transform
Example #11
    def log_like(self, Y, dT):

        prob = []
        self.pi_norm = torch.nn.LogSoftmax(dim=0)(self.pi)
        for k in range(self.K):
            p_g = Gamma(torch.exp(self.log_ab[k][0]),
                        torch.exp(self.log_ab[k][1]))
            prob_g = p_g.log_prob(dT)
            prob.append(prob_g + self.pi_norm[k])
        prob = torch.stack(prob).t() + self.likelihood.mixture_prob(Y)

        return prob
Example #12
def main():
    # TODO 1: add plot output
    epoch = 4001
    sampler = MarsagliaTsampler(size=1)
    optimizer = optim.SGD(sampler.parameters(), lr=0.001, momentum=0.9)

    alpha, beta = 10, 1
    target_distribution = Gamma(concentration=alpha, rate=beta)

    filenames = []
    path = '/extra/yadongl10/git_project/GammaLearningResult'
    for i in range(epoch):
        # compute loss
        samples = sampler(batch_size=128)
        samples = samples[samples > 0]

        loss = CE_Gamma(
            samples, target_distribution
        )  # For MarsagliaTsampler, currently only supports beta=1

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print intermediate results
        if i % 500 == 0:
            print('loss {}'.format(loss))
            plt.plot(np.linspace(0.001, 25, 1000),
                     target_distribution.log_prob(
                         torch.linspace(0.001, 25, 1000)).exp().tolist(),
                     label='Target: Gamma({},{})'.format(alpha, beta))
            test_samples = sampler(batch_size=1000)
            test_samples = test_samples[test_samples > 0]
            plt.hist(test_samples.tolist(),
                     bins=30,
                     density=True,
                     histtype='step')
            plt.xlim(0, 30)
            plt.ylim(0, 0.25)
            plt.legend()
            plt.title(
                'Histogram of Samples Generated from trained models at epoch: {}'
                .format(i))
            filenames.append('{}/epoch{}.png'.format(path, i))
            plt.savefig('{}/epoch{}.png'.format(path, i), dpi=200)
            plt.close()

    # make gif
    images = []
    for filename in filenames:
        images.append(imageio.imread(filename))
    imageio.mimwrite('{}/movie.gif'.format(path), images, duration=1)
Example #13
    def test_gamma_prior_log_prob(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        concentration = torch.tensor(1.0, device=device)
        rate = torch.tensor(1.0, device=device)
        prior = GammaPrior(concentration, rate)
        dist = Gamma(concentration, rate)

        t = torch.tensor(1.0, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        t = torch.tensor([1.5, 0.5], device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        t = torch.tensor([[1.0, 0.5], [3.0, 0.25]], device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
Example #14
    def test_gamma_prior_log_prob_log_transform(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        concentration = torch.tensor(1.0, device=device)
        rate = torch.tensor(1.0, device=device)
        prior = GammaPrior(concentration, rate, transform=torch.exp)
        dist = Gamma(concentration, rate)

        t = torch.tensor(0.0, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t.exp())))
        t = torch.tensor([-1, 0.5], device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t.exp())))
        t = torch.tensor([[-1, 0.5], [0.1, -2.0]], device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t.exp())))
Example #15
    def test_gamma_prior_batch_log_prob(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")

        concentration = torch.tensor([1.0, 2.0], device=device)
        rate = torch.tensor([1.0, 2.0], device=device)
        prior = GammaPrior(concentration, rate)
        dist = Gamma(concentration, rate)
        t = torch.ones(2, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        t = torch.ones(2, 2, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        with self.assertRaises(RuntimeError):
            prior.log_prob(torch.ones(3, device=device))

        mean = torch.tensor([[1.0, 2.0], [0.5, 3.0]], device=device)
        variance = torch.tensor([[1.0, 2.0], [0.5, 1.0]], device=device)
        prior = GammaPrior(mean, variance)
        dist = Gamma(mean, variance)
        t = torch.ones(2, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        t = torch.ones(2, 2, device=device)
        self.assertTrue(torch.equal(prior.log_prob(t), dist.log_prob(t)))
        with self.assertRaises(RuntimeError):
            prior.log_prob(torch.ones(3, device=device))
        with self.assertRaises(RuntimeError):
            prior.log_prob(torch.ones(2, 3, device=device))
Example #16
 def __init__(self, data, mu_prior, alpha_prior, W_df_prior, W_prior,
              G_df_prior, rate_prior):
     d = W_prior.shape[0]
     self.W_prior = Wishart({
         'df':
         torch.tensor([W_df_prior], dtype=torch.float64),
         'W':
         torch.from_numpy(W_prior.astype(np.float64))
     })
     self.nu_prior = Gamma(torch.tensor([G_df_prior], dtype=torch.float64),
                           torch.tensor([rate_prior], dtype=torch.float64))
     self.mu_prior = MultivariateNormal(
         loc=mu_prior * torch.ones(d, dtype=torch.float64),
         covariance_matrix=alpha_prior * torch.eye(d, dtype=torch.float64))
     self.data = data
Example #17
def sq_log_posterior_predictive_eval(x_new, kappa, tau_0, tau_1, S):
    T = kappa.shape[0] + 1
    q_beta = Beta(torch.ones(T - 1), kappa)
    q_lambda = Gamma(tau_0, tau_1)
    beta_mc = q_beta.sample([S])
    lambda_mc = q_lambda.sample([S])
    log_prob = 0
    for s in range(S):
        post_pred_weights = mix_weights(beta_mc[s])
        post_pred_clusters = lambda_mc[s]
        for t in range(post_pred_clusters.shape[0]):
            log_prob -= post_pred_weights[t] * torch.exp(
                Poisson(post_pred_clusters[t]).log_prob(x_new))**2
    log_prob /= S
    return log_prob
Example #18
    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = Gamma(self.prior_alpha, self.prior_beta).sample((num_samples, self.latent_dim))
        z = z.squeeze().to(current_device)

        samples = self.decode(z)
        return samples
Example #19
    def kl(self, phi):
        p = 5e-3
        qz = self.z_mean.expand_as(self.theta)
        kl_z = qz * torch.log(qz / p) + (1 - qz) * torch.log((1 - qz) / (1 - p))
        qlogp = torch.log(nlp_pdf(self.signed_theta, phi, tau=0.358)+1e-8)

        if isinstance(self.alternative_sampler, MarsagliaTsampler):
            gamma = Gamma(concentration=self.detached_gamma_alpha, rate=1.)
            qlogq = np.log(0.5) + gamma.log_prob(self.theta)  # use unsigned self.theta to compute qlogq
        elif isinstance(self.alternative_sampler, LogNormalSampler):
            qlogq = self.log_prob_alternative(self.signed_theta)

        kl_beta = qlogq - qlogp
        kl = (kl_z + qz*kl_beta).sum(dim=1).mean()
        return kl, kl_z, kl_beta
Example #20
def model(data, **kwargs):
    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))

    zeta = 2. * torch.ones(T * D, device=device)
    delta = 2. * torch.ones(T * D, device=device)
    with pyro.plate("prec_plate", T * D):
        prec = pyro.sample("prec", Gamma(zeta, delta))

    corr_chol = torch.zeros(T, D, D)
    for t in pyro.plate("corr_chol_plate", T):
        corr_chol[t, ...] = pyro.sample(
            "corr_chol_{}".format(t),
            LKJCorrCholesky(d=D, eta=torch.ones(1, device=device)))

    with pyro.plate("mu_plate", T):
        _std = torch.sqrt(1. / prec.view(-1, D))
        sigma_chol = torch.bmm(torch.diag_embed(_std), corr_chol)
        mu = pyro.sample(
            "mu",
            MultivariateNormal(torch.zeros(T, D, device=device),
                               scale_tril=sigma_chol))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs",
                    MultivariateNormal(mu[z], scale_tril=sigma_chol[z]),
                    obs=data)
Example #21
def draw_samples(path, name, model, test_loader, train_loader, method):
    sample = None
    if path.startswith('./assets/data/sbvae'):
        if method == 'km':
            sample = Beta(torch.tensor([1.0]),
                          torch.tensor([5.0])).rsample([64, 50]).squeeze().to(device)
        elif method == 'gamma':
            sample = Gamma(torch.tensor([1.0]),
                           torch.tensor([5.0])).rsample([64, 50]).squeeze().to(device)
        cumprods = torch.cat((torch.ones([64, 1], device=device),
                              torch.cumprod(1 - sample, axis=1)), dim=1)
        sample = cumprods[:, :-1] * sample
        sample[:, -1] = 1 - sample[:, :-1].sum(axis=1)
    else:
        sample = torch.randn(64, 50)
    sample = torch.sigmoid(model.decode(sample)).reshape(
        64, 28, 28).cpu().detach().numpy()
    f, axarr = plt.subplots(8, 8)
    for i in range(64):
        axarr[i // 8, i % 8].imshow(sample[i])
    plt.savefig(path + ".samples.pdf")
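The cumprod construction above is the stick-breaking map from the Beta/Gamma draws to mixture weights; after the last-column correction each row sums to one. A small hedged sketch with reduced shapes (names are illustrative):

import torch
from torch.distributions import Beta

sticks = Beta(torch.tensor([1.0]), torch.tensor([5.0])).rsample([4, 6]).squeeze()
cumprods = torch.cat((torch.ones(4, 1), torch.cumprod(1 - sticks, dim=1)), dim=1)
weights = cumprods[:, :-1] * sticks
weights[:, -1] = 1 - weights[:, :-1].sum(dim=1)
print(weights.sum(dim=1))  # each entry is 1.0 up to rounding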
Example #22
class WishartGammaNormalStudent:
    def __init__(self, data, mu_prior, alpha_prior, W_df_prior, W_prior,
                 G_df_prior, rate_prior):
        d = W_prior.shape[0]
        self.W_prior = Wishart({
            'df':
            torch.tensor([W_df_prior], dtype=torch.float64),
            'W':
            torch.from_numpy(W_prior.astype(np.float64))
        })
        self.nu_prior = Gamma(torch.tensor([G_df_prior], dtype=torch.float64),
                              torch.tensor([rate_prior], dtype=torch.float64))
        self.mu_prior = MultivariateNormal(
            loc=mu_prior * torch.ones(d, dtype=torch.float64),
            covariance_matrix=alpha_prior * torch.eye(d, dtype=torch.float64))
        self.data = data

    def log_prob(self, nuPmu):
        nu, P, mu = nuPmu
        if nu <= EPS:
            nu = torch.tensor([EPS], dtype=torch.float64)
        likelihood = MVtS({'df': nu, 'loc': mu, 'Sig': torch.inverse(P)}, P=P)
        return (torch.sum(likelihood.log_prob(self.data)) +
                self.W_prior.log_prob(P) + self.nu_prior.log_prob(nu) +
                self.mu_prior.log_prob(mu))

    def llh(self, nuPmu):
        nu, P, mu = nuPmu
        likelihood = MVtS({'df': nu, 'loc': mu, 'Sig': torch.inverse(P)}, P=P)
        return torch.sum(likelihood.log_prob(self.data))
Example #23
 def __init__(self, concentration, rate, validate_args=None):
     base_dist = Gamma(concentration, rate)
     super().__init__(
         base_dist,
         PowerTransform(-base_dist.rate.new_ones(())),
         validate_args=validate_args,
     )
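The constructor above builds an inverse-gamma distribution by pushing a Gamma base through PowerTransform(-1). A stand-alone sketch (names and values are illustrative) comparing the transformed log_prob with the analytic inverse-gamma log density c*log(r) - lgamma(c) - (c+1)*log(y) - r/y:

import torch
from torch.distributions import Gamma, TransformedDistribution
from torch.distributions.transforms import PowerTransform

c, r = torch.tensor(3.0), torch.tensor(2.0)
inv_gamma = TransformedDistribution(Gamma(c, r), PowerTransform(-torch.ones(())))

y = torch.tensor(0.7)
analytic = c * torch.log(r) - torch.lgamma(c) - (c + 1) * torch.log(y) - r / y
print(inv_gamma.log_prob(y), analytic)  # should agree up to numerical error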
Example #24
def posterior_predictive_sample(kappa, tau_0, tau_1, S, M):
    T = kappa.shape[0] + 1
    q_beta = Beta(torch.ones(T - 1), kappa)
    q_lambda = Gamma(tau_0, tau_1)
    beta_mc = q_beta.sample([S])
    lambda_mc = q_lambda.sample([S])

    hallucinated_samples = torch.zeros(S, M)
    for s in range(S):
        post_pred_weights = mix_weights(beta_mc[s])
        post_pred_clusters = lambda_mc[s]
        hallucinated_samples[s, :] = MixtureSameFamily(
            Categorical(post_pred_weights),
            Poisson(post_pred_clusters)).sample([M])

    return hallucinated_samples
Example #25
    def gen_default_priors(self,
                           K,
                           L,
                           sig_prior=LogNormal(0, 1),
                           alpha_prior=Gamma(1., 1.),
                           mu0_prior=None,
                           mu1_prior=None,
                           W_prior=None,
                           eta0_prior=None,
                           eta1_prior=None):

        if L is None:
            L = [5, 5]

        self.L = L

        if K is None:
            K = 30

        self.K = K

        # FIXME: these should be an ordered prior of TruncatedNormals
        if mu0_prior is None:
            mu0_prior = Gamma(1, 1)

        if mu1_prior is None:
            mu1_prior = Gamma(1, 1)

        if W_prior is None:
            W_prior = Dirichlet(torch.ones(self.K) / self.K)

        if eta0_prior is None:
            eta0_prior = Dirichlet(torch.ones(self.L[0]) / self.L[0])

        if eta1_prior is None:
            eta1_prior = Dirichlet(torch.ones(self.L[1]) / self.L[1])

        self.priors = {
            'mu0': mu0_prior,
            'mu1': mu1_prior,
            'sig': sig_prior,
            'H': Normal(0, 1),
            'eta0': eta0_prior,
            'eta1': eta1_prior,
            'W': W_prior,
            'alpha': alpha_prior
        }
Example #26
    def __init__(self,
                 dim,
                 act=nn.ReLU(),
                 num_hiddens=[50],
                 nout=1,
                 conf=dict()):
        nn.Module.__init__(self)
        BNN.__init__(self)
        self.dim = dim
        self.act = act
        self.num_hiddens = num_hiddens
        self.nout = nout
        self.steps_burnin = conf.get('steps_burnin', 2500)
        self.steps = conf.get('steps', 2500)
        self.keep_every = conf.get('keep_every', 50)
        self.batch_size = conf.get('batch_size', 32)
        self.warm_start = conf.get('warm_start', False)

        self.lr_weight = np.float32(conf.get('lr_weight', 1e-3))
        self.lr_noise = np.float32(conf.get('lr_noise', 1e-3))
        self.lr_lambda = np.float32(conf.get('lr_lambda', 1e-3))
        self.alpha_w = torch.as_tensor(1. * conf.get('alpha_w', 6.))
        self.beta_w = torch.as_tensor(1. * conf.get('beta_w', 6.))
        self.alpha_n = torch.as_tensor(1. * conf.get('alpha_n', 6.))
        self.beta_n = torch.as_tensor(1. * conf.get('beta_n', 6.))
        self.noise_level = conf.get('noise_level', None)
        if self.noise_level is not None:
            prec = 1 / self.noise_level**2
            prec_var = (prec * 0.25)**2
            self.beta_n = torch.as_tensor(prec / prec_var)
            self.alpha_n = torch.as_tensor(prec * self.beta_n)
            print("Reset alpha_n = %g, beta_n = %g" %
                  (self.alpha_n, self.beta_n))

        self.prior_log_lambda = TransformedDistribution(
            Gamma(self.alpha_w, self.beta_w),
            ExpTransform().inv)  # log of gamma distribution
        self.prior_log_precision = TransformedDistribution(
            Gamma(self.alpha_n, self.beta_n),
            ExpTransform().inv)

        self.log_lambda = nn.Parameter(torch.tensor(0.))
        self.log_precs = nn.Parameter(torch.zeros(self.nout))
        self.nn = NN(dim, self.act, self.num_hiddens, self.nout)

        self.init_nn()
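The two priors above are log-transformed Gammas: TransformedDistribution(Gamma(a, b), ExpTransform().inv) is the distribution of log X for X ~ Gamma(a, b), so its log density at t is Gamma(a, b).log_prob(exp(t)) + t, the +t being the log-Jacobian of exp. A quick stand-alone check with illustrative values:

import torch
from torch.distributions import Gamma, TransformedDistribution
from torch.distributions.transforms import ExpTransform

a, b = torch.tensor(6.0), torch.tensor(6.0)
log_gamma = TransformedDistribution(Gamma(a, b), ExpTransform().inv)

t = torch.tensor(-0.3)
print(log_gamma.log_prob(t), Gamma(a, b).log_prob(t.exp()) + t)  # should match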
Example #27
    def generate(
        self,
        n_samples: int = 100,
        batch_size: int = 64
    ):  # with n_samples>1 return original list/ otherwise sequential
        """
        Return samples from posterior predictive. Proteins are concatenated to genes.

        :param n_samples: Number of posterior predictive samples
        :return: Tuple of posterior samples, original data
        """
        original_list = []
        posterior_list = []
        for tensors in self.update({"batch_size": batch_size}):
            x, _, _, batch_index, labels, y = tensors
            with torch.no_grad():
                outputs = self.model.inference(x,
                                               y,
                                               batch_index=batch_index,
                                               label=labels,
                                               n_samples=n_samples)
            px_ = outputs["px_"]
            py_ = outputs["py_"]

            pi = 1 / (1 + torch.exp(-py_["mixing"]))
            mixing_sample = Bernoulli(pi).sample()
            protein_rate = (py_["rate_fore"] * (1 - mixing_sample) +
                            py_["rate_back"] * mixing_sample)
            rate = torch.cat((px_["rate"], protein_rate), dim=-1)
            if len(px_["r"].size()) == 2:
                px_dispersion = px_["r"]
            else:
                px_dispersion = torch.ones_like(x) * px_["r"]
            if len(py_["r"].size()) == 2:
                py_dispersion = py_["r"]
            else:
                py_dispersion = torch.ones_like(y) * py_["r"]

            dispersion = torch.cat((px_dispersion, py_dispersion), dim=-1)

            # This gamma is really l*w using scVI manuscript notation
            p = rate / (rate + dispersion)
            r = dispersion
            l_train = Gamma(r, (1 - p) / p).sample()
            data = Poisson(l_train).sample().cpu().numpy()
            # """
            # In numpy (shape, scale) => (concentration, rate), with scale = p /(1 - p)
            # rate = (1 - p) / p  # = 1/scale # used in pytorch
            # """
            original_list += [np.array(torch.cat((x, y), dim=-1).cpu())]
            posterior_list += [data]

            posterior_list[-1] = np.transpose(posterior_list[-1], (1, 2, 0))

        return (
            np.concatenate(posterior_list, axis=0),
            np.concatenate(original_list, axis=0),
        )
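A hedged aside on the Gamma-Poisson draw above: it samples a Negative Binomial with mean `rate` and shape `dispersion` (variance rate + rate**2 / dispersion), which coincides with torch.distributions.NegativeBinomial(total_count=dispersion, probs=p). A small stand-alone check with arbitrary numbers:

import torch
from torch.distributions import Gamma, NegativeBinomial, Poisson

rate, dispersion = torch.tensor(10.0), torch.tensor(2.0)
p = rate / (rate + dispersion)

l_train = Gamma(dispersion, (1 - p) / p).sample((200_000,))
counts = Poisson(l_train).sample()

nb = NegativeBinomial(total_count=dispersion, probs=p)
print(counts.mean(), nb.mean)      # both roughly rate
print(counts.var(), nb.variance)   # both roughly rate / (1 - p)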
Example #28
    def log_like_dT(self, dT, log_ab):
        """
        Calculates per-class log likelihood of latent embedding and intervals
        Args:
            h (): latent embedding (N, D_h)
            dT (): interbout intervals (T)
            log_ab (): per-class parameters (K, 2, N)

        Returns: (N, K)

        """

        prob = []
        for k in range(self.K):
            p_g = Gamma(torch.exp(log_ab[k][0]), torch.exp(log_ab[k][1]))
            prob_g = p_g.log_prob(dT)
            prob.append(prob_g)
        prob = torch.stack(prob).t()

        return prob
Example #29
    def gen_default_priors(self,
                           data,
                           K,
                           L,
                           sig_prior=Gamma(1, 1),
                           alpha_prior=Gamma(1, 1),
                           mu0_prior=None,
                           mu1_prior=None,
                           W_prior=None,
                           eta0_prior=None,
                           eta1_prior=None):

        if L is None:
            L = [5, 5]

        self.__cache_model_constants__(data, K, L)

        if mu0_prior is None:
            mu0_prior = Gamma(1, 1)

        if mu1_prior is None:
            mu1_prior = Gamma(1, 1)

        if W_prior is None:
            W_prior = Dirichlet(torch.ones(self.K) / self.K)

        if eta0_prior is None:
            eta0_prior = Dirichlet(torch.ones(self.L[0]) / self.L[0])

        if eta1_prior is None:
            eta1_prior = Dirichlet(torch.ones(self.L[1]) / self.L[1])

        self.priors = {
            'mu0': mu0_prior,
            'mu1': mu1_prior,
            'sig': sig_prior,
            'eta0': eta0_prior,
            'eta1': eta1_prior,
            'W': W_prior,
            'alpha': alpha_prior
        }
Example #30
def test_log_prob(batch_shape, dim):
    loc = torch.randn(batch_shape + (dim, ))
    A = torch.randn(batch_shape + (dim, dim + dim))
    scale_tril = A.matmul(A.transpose(-2, -1)).cholesky()
    x = torch.randn(batch_shape + (dim, ))
    df = torch.randn(batch_shape).exp() + 2
    actual_log_prob = MultivariateStudentT(df, loc, scale_tril).log_prob(x)

    if dim == 1:
        expected_log_prob = StudentT(df.unsqueeze(-1), loc,
                                     scale_tril[..., 0]).log_prob(x).sum(-1)
        assert_equal(actual_log_prob, expected_log_prob)

    # test the fact MVT(df, loc, scale)(x) = int MVN(loc, scale / m)(x) Gamma(df/2,df/2)(m) dm
    num_samples = 100000
    gamma_samples = Gamma(df / 2, df / 2).sample(sample_shape=(num_samples, ))
    mvn_scale_tril = scale_tril / gamma_samples.sqrt().unsqueeze(-1).unsqueeze(
        -1)
    mvn = MultivariateNormal(loc, scale_tril=mvn_scale_tril)
    expected_log_prob = mvn.log_prob(x).logsumexp(0) - math.log(num_samples)
    assert_equal(actual_log_prob, expected_log_prob, prec=0.01)