コード例 #1
0
    def test_mixture_of_mvn(self):
        """Compare a two-component MvNormal mixture's logp with SciPy references.

        Three quantities are checked in turn: the per-component log-densities,
        the element-wise mixture log-density at the model's initial point, and
        the total model log-probability (mixture likelihood plus the Dirichlet
        prior on the weights ``w``).
        """
        means = (np.asarray([0.0, 1.0]), np.asarray([1.0, 0.0]))
        covs = (np.diag([1.5, 2.5]), np.diag([2.5, 3.5]))
        # Observations: one arbitrary point followed by each component's mean.
        obs = np.asarray([[0.5, 0.5], means[0], means[1]])

        with Model() as model:
            weights = Dirichlet("w", floatX(np.ones(2)), transform=None, shape=(2,))
            components = [MvNormal.dist(mu=m, cov=c) for m, c in zip(means, covs)]
            y = Mixture("x_obs", weights, components, observed=obs)

        # Per-component logp: one row per observation, one column per component.
        expected_comp_logp = np.column_stack(
            [st.multivariate_normal.logpdf(obs, m, c) for m, c in zip(means, covs)]
        )
        actual_comp_logp = y.distribution._comp_logp(aesara.shared(obs)).eval()
        assert_allclose(actual_comp_logp, expected_comp_logp)

        # Mixture logp: log-sum-exp over components of log(w) + component logp.
        testpoint = model.recompute_initial_point()
        log_w = np.log(testpoint["w"])
        expected_mix_logp = logsumexp(log_w + expected_comp_logp, axis=-1, keepdims=False)
        assert_allclose(y.logp_elemwise(testpoint), expected_mix_logp)

        # Model logp: summed mixture logp plus the Dirichlet prior on the weights.
        prior_logp = st.dirichlet.logpdf(x=testpoint["w"], alpha=np.ones(2))
        assert_allclose(model.logp(testpoint), expected_mix_logp.sum() + prior_logp)
コード例 #2
0
    def test_mixture_of_mixture(self):
        """Check a mixture whose two components are themselves mixtures.

        The model mixes a Normal mixture and a LogNormal mixture (``nbr``
        components each, Dirichlet weights) and observes ``np.exp(self.norm_x)``.
        The element-wise mixture logp and the total model logp are compared with
        a NumPy/SciPy reference, first at the model's initial point and then
        again after overriding ``g_w`` and ``mu_g`` in that point.
        """
        # float32 graphs cannot match the reference as tightly as float64 ones.
        rtol = 1e-4 if aesara.config.floatX == "float32" else 1e-7
        nbr = 4
        tiny = 0.0000001  # Dirichlet concentration used for both inner weight vectors

        with Model() as model:
            # Inner mixture components; each mean vector is a free Exponential variable.
            mu_g = Exponential("mu_g", lam=1.0, shape=nbr, transform=None)
            g_comp = Normal.dist(mu=mu_g, sigma=1, shape=nbr)
            mu_l = Exponential("mu_l", lam=1.0, shape=nbr, transform=None)
            l_comp = LogNormal.dist(mu=mu_l, sigma=1, shape=nbr)
            # Weights of the two inner mixtures.
            g_w = Dirichlet("g_w", a=floatX(np.ones(nbr) * tiny), transform=None, shape=(nbr,))
            l_w = Dirichlet("l_w", a=floatX(np.ones(nbr) * tiny), transform=None, shape=(nbr,))
            # The inner mixtures themselves.
            g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
            l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
            # Outer mixture combining the two inner mixtures.
            mix_w = Dirichlet("mix_w", a=floatX(np.ones(2)), transform=None, shape=(2,))
            mix = Mixture("mix", w=mix_w, comp_dists=[g_mix, l_mix], observed=np.exp(self.norm_x))

        test_point = model.recompute_initial_point()

        def reference_logps(value, point):
            """Return (prior logp, element-wise outer-mixture logp) computed with SciPy."""
            dtype = aesara.config.floatX

            def dirichlet_logp(name, alpha):
                return st.dirichlet.logpdf(x=point[name], alpha=alpha).astype(dtype)

            def expon_logp(name):
                return st.expon.logpdf(x=point[name]).sum(dtype=dtype)

            def weighted_logsumexp(name, comp_logp, keepdims):
                log_w = np.log(point[name]).astype(dtype)
                return logsumexp(log_w + comp_logp, axis=-1, keepdims=keepdims)

            prior = (
                dirichlet_logp("g_w", np.ones(nbr) * tiny)
                + expon_logp("mu_g")
                + dirichlet_logp("l_w", np.ones(nbr) * tiny)
                + expon_logp("mu_l")
                + dirichlet_logp("mix_w", np.ones(2))
            )

            normal_logp = st.norm.logpdf(x=value, loc=point["mu_g"]).astype(dtype)
            normal_mix = weighted_logsumexp("g_w", normal_logp, keepdims=True)
            # SciPy's lognorm positional arguments are (x, s, loc, scale).
            lognormal_logp = st.lognorm.logpdf(value, 1.0, 0.0, np.exp(point["mu_l"])).astype(
                dtype
            )
            lognormal_mix = weighted_logsumexp("l_w", lognormal_logp, keepdims=True)

            inner_logp = np.concatenate((normal_mix, lognormal_mix), axis=1)
            outer = weighted_logsumexp("mix_w", inner_logp, keepdims=False)
            return prior, outer

        # Column vector so the observations broadcast against the component axis.
        value = np.exp(self.norm_x)[:, None]

        def check_against_reference(point):
            prior, outer = reference_logps(value, point)
            # Element-wise logp of the outer mixture.
            assert_allclose(outer, mix.logp_elemwise(point), rtol=rtol)
            # Total model logp: priors plus the summed mixture logp.
            assert_allclose(prior + outer.sum(), model.logp(point), rtol=rtol)

        check_against_reference(test_point)

        # Move away from the initial point and verify both quantities again.
        test_point["g_w"] = np.asarray([0.1, 0.1, 0.2, 0.6])
        test_point["mu_g"] = np.exp(np.random.randn(nbr))
        check_against_reference(test_point)