def test_dimensions(self):
    """Check ``ndim`` of the mode and logp of a two-component Normal mixture.

    The mixture's ``mode`` and its ``logp`` at a scalar value must both be
    0-dimensional; ``logp`` at a length-3 vector must be 1-dimensional.
    """
    components = [
        Normal.dist(mu=0, sigma=1),
        Normal.dist(mu=10, sigma=1),
    ]
    weights = np.r_[0.5, 0.5]
    mix = Mixture.dist(w=weights, comp_dists=components)

    # Scalar quantities carry no dimensions.
    assert mix.mode.ndim == 0
    assert mix.logp(0.0).ndim == 0

    # A vector of values yields a vector of log-probabilities.
    vector_value = np.r_[0.0, 1.0, 2.0]
    assert mix.logp(vector_value).ndim == 1
def test_mixture_of_mixture(self):
    """Compare a mixture-of-mixtures logp with a NumPy/SciPy reference.

    A model is built whose observed variable ``mix`` is a two-component
    mixture of (a) a mixture of ``nbr`` Normal components and (b) a mixture
    of ``nbr`` LogNormal components, observed at ``np.exp(self.norm_x)``.
    Both ``mix.logp_elemwise`` and ``model.logp`` are compared with values
    computed by hand via ``scipy.stats`` and ``logsumexp``: first at the
    model's initial point, then again after overriding ``g_w`` and ``mu_g``.
    The relative tolerance is looser when ``aesara.config.floatX`` is
    ``"float32"``.
    """
    rtol = 1e-4 if aesara.config.floatX == "float32" else 1e-7
    nbr = 4

    with Model() as model:
        # Component distributions; their locations are free Exponential RVs.
        mu_g = Exponential("mu_g", lam=1.0, shape=nbr, transform=None)
        g_comp = Normal.dist(mu=mu_g, sigma=1, shape=nbr)
        mu_l = Exponential("mu_l", lam=1.0, shape=nbr, transform=None)
        l_comp = LogNormal.dist(mu=mu_l, sigma=1, shape=nbr)

        # Weight vectors of the two inner mixtures.
        g_w = Dirichlet(
            "g_w", a=floatX(np.ones(nbr) * 0.0000001), transform=None, shape=(nbr,)
        )
        l_w = Dirichlet(
            "l_w", a=floatX(np.ones(nbr) * 0.0000001), transform=None, shape=(nbr,)
        )

        # Inner mixtures.
        g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
        l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)

        # Outer mixture combining the two inner mixtures.
        mix_w = Dirichlet("mix_w", a=floatX(np.ones(2)), transform=None, shape=(2,))
        mix = Mixture(
            "mix",
            w=mix_w,
            comp_dists=[g_mix, l_mix],
            observed=np.exp(self.norm_x),
        )

    test_point = model.recompute_initial_point()

    def reference_logp(value, point):
        """Return ``(prior logp, element-wise mixture logp)`` for ``point``."""
        dtype = aesara.config.floatX

        def dirichlet_logp(name, alpha):
            return st.dirichlet.logpdf(x=point[name], alpha=alpha).astype(dtype)

        def expon_logp(name):
            return st.expon.logpdf(x=point[name]).sum(dtype=dtype)

        def log_weights(name):
            return np.log(point[name]).astype(dtype)

        # Terms are added in a fixed left-to-right order.
        prior = (
            dirichlet_logp("g_w", np.ones(nbr) * 0.0000001)
            + expon_logp("mu_g")
            + dirichlet_logp("l_w", np.ones(nbr) * 0.0000001)
            + expon_logp("mu_l")
            + dirichlet_logp("mix_w", np.ones(2))
        )

        # Normal inner mixture.
        norm_comp = st.norm.logpdf(x=value, loc=point["mu_g"]).astype(dtype)
        norm_mix = logsumexp(log_weights("g_w") + norm_comp, axis=-1, keepdims=True)

        # LogNormal inner mixture.
        lognorm_comp = st.lognorm.logpdf(
            value, 1.0, 0.0, np.exp(point["mu_l"])
        ).astype(dtype)
        lognorm_mix = logsumexp(
            log_weights("l_w") + lognorm_comp, axis=-1, keepdims=True
        )

        # Outer mixture over the two inner mixture log-densities.
        inner = np.concatenate((norm_mix, lognorm_mix), axis=1)
        elemwise = logsumexp(log_weights("mix_w") + inner, axis=-1, keepdims=False)
        return prior, elemwise

    value = np.exp(self.norm_x)[:, None]

    def check_against_reference(point):
        prior, elemwise = reference_logp(value, point)
        # Element-wise logp of the observed mixture.
        assert_allclose(elemwise, mix.logp_elemwise(point), rtol=rtol)
        # Total model logp (priors plus summed observation logp).
        assert_allclose(prior + elemwise.sum(), model.logp(point), rtol=rtol)

    check_against_reference(test_point)

    # Override part of the point and verify both logp values again.
    test_point["g_w"] = np.asarray([0.1, 0.1, 0.2, 0.6])
    test_point["mu_g"] = np.exp(np.random.randn(nbr))
    check_against_reference(test_point)