Ejemplo n.º 1
0
def multivariate_gaussian_DPMM_isotropic(data: jnp.ndarray,
                                         alpha: float = 1,
                                         sigma: float = 0,
                                         T: int = 10):
    """Truncated stick-breaking mixture of multivariate Gaussians, isotropic prior.

    Component means are drawn as ``mu[k] ~ MVN(mu_bar, sigma2_mu * I)``, with
    ``mu_bar`` and ``sigma2_mu`` taken from ``richardson_component_prior``.
    A per-component ``sigma2 ~ InverseGamma(0.5, kappa)`` with
    ``kappa ~ Gamma(2, sigma2_mu)`` is sampled as well, but the likelihood
    currently uses the identity covariance for every component (see TODO).

    Args:
        data: observations, shape ``(Npoints, Ndim)``.
        alpha, sigma: forwarded to ``sample_beta_PY`` to draw the ``T - 1``
            stick-breaking fractions.
        T: number of mixture components (truncation level).
    """
    Npoints, Ndim = data.shape

    mu_bar, sigma2_mu = richardson_component_prior(data)
    assert mu_bar.shape == (Ndim, )
    assert isinstance(sigma2_mu, float)

    beta = sample_beta_PY(alpha=alpha, sigma=sigma, T=T)
    assert beta.shape == (T - 1, )

    with numpyro.plate("component_plate", T):
        mu = numpyro.sample(
            "mu", MultivariateNormal(mu_bar, sigma2_mu * jnp.eye(Ndim)))
        assert mu.shape == (T, Ndim), (mu.shape, T, Ndim)

        kappa = numpyro.sample("kappa", Gamma(2, sigma2_mu))
        assert kappa.shape == (T, ), (kappa.shape, T)

        # NOTE(review): the original author reported that this site "seems to
        # make everything fail". Possibly because it never enters the
        # likelihood, so the sampler has to explore the heavy-tailed
        # InverseGamma(0.5, kappa) prior unconstrained by data — TODO confirm.
        # The site name "sigma2_inv" is kept as-is (inference code may refer
        # to it) even though the value is bound to `sigma2` and is meant to be
        # used as the variance itself, not its inverse.
        sigma2 = numpyro.sample("sigma2_inv", InverseGamma(.5, kappa))

    with numpyro.plate("data", Npoints):
        z = numpyro.sample("z", Categorical(mix_weights(beta)))

        # TODO: use the per-component variance instead of the identity, i.e.
        # covariance_matrix = sigma2[z][:, None, None] * jnp.eye(Ndim).
        numpyro.sample("obs",
                       MultivariateNormal(mu[z], jnp.eye(Ndim)),
                       obs=data)
Ejemplo n.º 2
0
    def gibbs_fn(rng_key: random.PRNGKey, gibbs_sites: Dict[str, jnp.ndarray],
                 hmc_sites: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Gibbs update of the cluster assignments ``z`` (full-covariance mixture).

        Each ``z[i]`` is drawn from its full conditional, with logits
        ``log MVN(data[i] | mu[k], L_Omega[k] L_Omega[k]^T) + log w[k]``, where
        ``w = mix_weights(beta)``. ``data``, ``Npoints`` and ``Ndim`` are read
        from the enclosing scope.

        Args:
            rng_key: JAX PRNG key used for the categorical draw.
            gibbs_sites: current Gibbs-site values; unused, since the new ``z``
                does not depend on the previous one.
            hmc_sites: current HMC values of ``beta`` ``(T - 1,)``, ``mu``
                ``(T, Ndim)``, ``theta`` ``(Ndim, T)`` and ``L_omega``
                ``(T, Ndim, Ndim)``.

        Returns:
            ``{'z': z}`` with ``z`` of shape ``(Npoints,)``.
        """
        beta = hmc_sites['beta']
        mu = hmc_sites['mu']
        theta = hmc_sites['theta']
        L_omega = hmc_sites['L_omega']
        # diag(sqrt(theta[:, k])) @ L_omega[k]: scales row i of each
        # correlation Cholesky factor, giving the covariance Cholesky factor.
        L_Omega = jnp.sqrt(theta.T[:, :, None]) * L_omega

        T, _ = mu.shape

        assert beta.shape == (T - 1, )
        assert mu.shape == (T, Ndim)
        assert theta.shape == (Ndim, T)
        assert L_omega.shape == (T, Ndim, Ndim)
        assert L_Omega.shape == (T, Ndim, Ndim)

        # data[:, None] is (Npoints, 1, Ndim); broadcasting against the T
        # components yields one log-density per (point, component).
        log_probs = MultivariateNormal(loc=mu,
                                       scale_tril=L_Omega).log_prob(data[:,
                                                                         None])
        assert log_probs.shape == (Npoints, T)

        log_weights = jnp.log(mix_weights(beta))
        assert log_weights.shape == (T, )

        logits = log_probs + log_weights[None, :]
        assert logits.shape == (Npoints, T)

        # No numpyro.plate here: plates only affect numpyro.sample sites, and
        # the (Npoints, T) logits already give one independent draw per point.
        z = CategoricalLogits(logits).sample(rng_key)
        assert z.shape == (Npoints, )
        return {'z': z}
Ejemplo n.º 3
0
def multivariate_gaussian_DPMM(data: jnp.ndarray,
                               alpha: float = 1,
                               sigma: float = 0,
                               T: int = 10):
    """Truncated stick-breaking mixture of full-covariance Gaussians.

    Component ``k`` has mean ``mu[k] ~ MVN(mu_bar, sigma2_mu * I)`` (prior
    parameters from ``richardson_component_prior``) and a covariance whose
    Cholesky factor is ``diag(sqrt(theta[:, k])) @ L_omega[k]``, with
    ``theta ~ HalfCauchy(1)`` per dimension and
    ``L_omega ~ LKJCholesky(Ndim, 1)``.

    Args:
        data: observations, shape ``(Npoints, Ndim)``.
        alpha, sigma: forwarded to ``sample_beta_PY``.
        T: number of mixture components (truncation level).
    """
    n_points, n_dim = data.shape
    loc0, var0 = richardson_component_prior(data)

    stick = sample_beta_PY(alpha=alpha, sigma=sigma, T=T)

    with numpyro.plate("component_plate", T):
        means = numpyro.sample(
            "mu", MultivariateNormal(loc0, var0 * jnp.eye(n_dim)))

        # Construction follows http://pyro.ai/examples/lkj.html
        with numpyro.plate("dim", n_dim):
            # Nested plates make this (Ndim, T).
            theta = numpyro.sample("theta", HalfCauchy(1))
        L_omega = numpyro.sample("L_omega", LKJCholesky(n_dim, 1))
        # Row i of each correlation factor is scaled by sqrt(theta[i, k]).
        chol = jnp.sqrt(theta.T)[:, :, None] * L_omega

    with numpyro.plate("data", n_points):
        z = numpyro.sample("z", Categorical(mix_weights(stick)))

        assert means.shape == (T, n_dim)
        assert chol.shape == (T, n_dim, n_dim)

        likelihood = MultivariateNormal(means[z], scale_tril=chol[z])
        numpyro.sample("obs", likelihood, obs=data)
Ejemplo n.º 4
0
    def gibbs_fn(rng_key: random.PRNGKey, gibbs_sites: Dict[str, jnp.ndarray],
                 hmc_sites: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
        """Gibbs update of the cluster assignments ``z`` (univariate Gaussian mixture).

        Each ``z[i]`` is drawn from its full conditional, with logits
        ``log N(data[i] | mu[k], sigma2[k]) + log w[k]``, where
        ``w = mix_weights(beta)``. ``data`` and ``Npoints`` are read from the
        enclosing scope.

        Args:
            rng_key: JAX PRNG key used for the categorical draw.
            gibbs_sites: current Gibbs-site values; unused, since the new ``z``
                does not depend on the previous one.
            hmc_sites: current HMC values of ``beta`` ``(T - 1,)``, ``mu``
                ``(T,)`` and ``sigma2`` ``(T,)``.

        Returns:
            ``{'z': z}`` with ``z`` of shape ``(Npoints,)``.
        """
        beta = hmc_sites['beta']
        mu = hmc_sites['mu']
        sigma2 = hmc_sites['sigma2']

        T, = mu.shape
        assert beta.shape == (T - 1, )
        assert sigma2.shape == (T, )

        # data[:, None] is (Npoints, 1); broadcasting against the T components
        # yields one log-density per (point, component).
        log_probs = Normal(loc=mu, scale=jnp.sqrt(sigma2)).log_prob(data[:,
                                                                         None])
        assert log_probs.shape == (Npoints, T)

        log_weights = jnp.log(mix_weights(beta))
        assert log_weights.shape == (T, )

        logits = log_probs + log_weights[None, :]
        assert logits.shape == (Npoints, T)

        # No numpyro.plate here: plates only affect numpyro.sample sites, and
        # the (Npoints, T) logits already give one independent draw per point.
        z = CategoricalLogits(logits).sample(rng_key)
        assert z.shape == (Npoints, )

        return {'z': z}
Ejemplo n.º 5
0
 def forward(self, X):
     """Compute the two terms of the negative ELBO for a batch ``X``.

     ``self.encode`` returns a distribution (sampled with ``rsample``), a
     tensor ``alpha`` (used here only as a template for ``torch.ones_like``)
     and a third, ignored value. The prior is element-wise
     ``Beta(1, self.prior_alpha)`` with the last dimension reinterpreted as
     the event dimension. The reparameterised sample is passed through
     ``mix_weights`` (presumably stick-breaking — TODO confirm) and its last
     column is dropped before decoding.

     Returns:
         ``(nll, kl)``: the mean negative log-likelihood of ``X`` under the
         decoded distribution, and the mean KL(encoded || prior).
     """
     posterior, alpha, _ = self.encode(X)
     ones = torch.ones_like(alpha)
     prior = Independent(Beta(ones, ones * self.prior_alpha), 1)

     weights = mix_weights(posterior.rsample())
     likelihood = self.decode(weights[:, :-1])

     nll = -likelihood.log_prob(X).mean()
     kl = kl_divergence(posterior, prior).mean()
     return nll, kl
Ejemplo n.º 6
0
def poisson_DPMM(data: jnp.ndarray,
                 alpha: float = 1,
                 sigma: float = 0,
                 T: int = 10):
    """Truncated stick-breaking mixture of Poisson distributions.

    Each of the ``T`` components has a rate ``~ Gamma(1, 1)``; every
    observation picks a component ``z`` with weights ``mix_weights(beta)``.

    Args:
        data: 1-D array of observed counts (its first axis is the data plate).
        alpha, sigma: forwarded to ``sample_beta_PY``.
        T: number of mixture components (truncation level).
    """
    stick = sample_beta_PY(alpha=alpha, sigma=sigma, T=T)

    with numpyro.plate("component_plate", T):
        rates = numpyro.sample("rate", Gamma(1, 1))

    n_obs = data.shape[0]
    with numpyro.plate("data", n_obs):
        weights = mix_weights(stick)
        assignment = numpyro.sample("z", Categorical(weights))
        numpyro.sample("obs", Poisson(rates[assignment]), obs=data)
Ejemplo n.º 7
0
def gaussian_DPMM(data: jnp.ndarray,
                  alpha: float = 1,
                  sigma: float = 0,
                  T: int = 10):
    """Truncated stick-breaking mixture of univariate Gaussians.

    Per component: ``mu ~ N(mu_bar, sigma2_mu)``, ``kappa ~ Gamma(2, sigma2_mu)``
    and ``sigma2 ~ InverseGamma(0.5, kappa)``, with ``mu_bar`` / ``sigma2_mu``
    from ``richardson_component_prior``.

    Args:
        data: 1-D array of observations, shape ``(Npoints,)``.
        alpha, sigma: forwarded to ``sample_beta_PY``.
        T: number of mixture components (truncation level).
    """
    (n_points,) = data.shape
    prior_mean, prior_var = richardson_component_prior(data)

    stick = sample_beta_PY(alpha=alpha, sigma=sigma, T=T)

    with numpyro.plate("component_plate", T):
        means = numpyro.sample("mu", Normal(prior_mean, jnp.sqrt(prior_var)))
        rate_hyper = numpyro.sample("kappa", Gamma(2, prior_var))
        variances = numpyro.sample("sigma2", InverseGamma(.5, rate_hyper))

    with numpyro.plate("data", n_points):
        assignment = numpyro.sample("z", Categorical(mix_weights(stick)))
        scale = jnp.sqrt(variances[assignment])
        numpyro.sample("obs", Normal(means[assignment], scale), obs=data)
Ejemplo n.º 8
0
    def gibbs_fn(rng_key: random.PRNGKey, gibbs_sites, hmc_sites):
        """Resample the Poisson-mixture assignments ``z`` from their full conditional.

        Logits are ``log Poisson(data[i] | rate[k]) + log w[k]`` with
        ``w = mix_weights(beta)``; ``data`` is read from the enclosing scope.

        Args:
            rng_key: JAX PRNG key used for the categorical draw.
            gibbs_sites: current Gibbs-site values (not read).
            hmc_sites: current HMC values of ``rate`` ``(T,)`` and ``beta``
                ``(T - 1,)``.

        Returns:
            ``{'z': z}`` with one assignment per data point.
        """
        rate, beta = hmc_sites['rate'], hmc_sites['beta']

        (n_components,) = rate.shape
        assert beta.shape == (n_components - 1, )

        (n_obs,) = data.shape

        # (N, 1) counts broadcast against T rates -> (N, T) log-likelihoods.
        per_component = Poisson(rate).log_prob(data[:, None])
        assert per_component.shape == (n_obs, n_components)

        log_w = jnp.log(mix_weights(beta))
        assert log_w.shape == (n_components, )

        scores = per_component + log_w[None, :]
        assert scores.shape == (n_obs, n_components)

        assignments = CategoricalLogits(scores).sample(rng_key)
        assert assignments.shape == (n_obs, )

        return {'z': assignments}