def test_normal_mixture(self):
    """Recover NormalMixture weights and means by Metropolis sampling.

    Fits a NormalMixture with Dirichlet weights, Normal means and Gamma
    precisions (one per component) to ``self.norm_x``. It then checks that
    the posterior means of ``w`` and ``mu`` are close to the generating
    ``self.norm_w`` / ``self.norm_mu``. Both sides are sorted so that the
    comparison does not depend on the component labelling.
    """
    ncomp = self.norm_w.size
    with Model():
        w = Dirichlet('w', np.ones_like(self.norm_w), shape=ncomp)
        mu = Normal('mu', 0., 10., shape=ncomp)
        tau = Gamma('tau', 1., 1., shape=ncomp)
        NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)
        step = Metropolis()
        # A single chain: mixture components are identified only up to
        # permutation. If chains settled on different labellings,
        # pooling them in ``mean(axis=0)`` below would average unrelated
        # components before the sort can line them up.
        trace = sample(5000, step, random_seed=self.random_seed,
                       progressbar=False, chains=1)
    assert_allclose(np.sort(trace['w'].mean(axis=0)),
                    np.sort(self.norm_w),
                    rtol=0.1, atol=0.1)
    assert_allclose(np.sort(trace['mu'].mean(axis=0)),
                    np.sort(self.norm_mu),
                    rtol=0.1, atol=0.1)
def test_normal_mixture_nd(self):
    """NormalMixture must agree with a Mixture built from per-column Normals.

    Both models declare the same free variables:
    - ``mus`` and ``taus`` of shape ``(nd, ncomp)``;
    - weights ``ws`` over ``ncomp`` components;
    - a mixture ``m`` of shape ``nd``.

    The first model uses ``NormalMixture``. The second gives ``Mixture`` one
    ``Normal.dist`` per component, built from column ``k`` of the parameter
    arrays. Model and mixture log-probabilities are compared at model0's
    test point, with ``mus`` replaced by random draws.
    """
    nd, ncomp = 3, 5
    param_shape = (nd, ncomp)

    def declare_priors():
        # Registers 'mus', 'taus', 'ws' (in that order) on the active Model.
        mu_rv = Normal('mus', shape=param_shape)
        tau_rv = Gamma('taus', alpha=1, beta=1, shape=param_shape)
        w_rv = Dirichlet('ws', np.ones(ncomp))
        return mu_rv, tau_rv, w_rv

    with Model() as model0:
        mu_rv, tau_rv, w_rv = declare_priors()
        mixture0 = NormalMixture('m', w=w_rv, mu=mu_rv, tau=tau_rv, shape=nd)

    with Model() as model1:
        mu_rv, tau_rv, w_rv = declare_priors()
        components = [Normal.dist(mu=mu_rv[:, k], tau=tau_rv[:, k])
                      for k in range(ncomp)]
        mixture1 = Mixture('m', w=w_rv, comp_dists=components, shape=nd)

    point = model0.test_point
    point['mus'] = np.random.randn(*param_shape)
    for lhs, rhs in ((model0, model1), (mixture0, mixture1)):
        assert_allclose(lhs.logp(point), rhs.logp(point))
def test_normal_mixture(self):
    """Check that Metropolis recovers the NormalMixture weights and means.

    The posterior means of ``w`` and ``mu``, taken from a single chain of
    5000 draws, are compared with the generating ``self.norm_w`` and
    ``self.norm_mu``. Both sides are sorted to remove the dependence on
    component order.
    """
    n_components = self.norm_w.size
    with Model():
        weights = Dirichlet("w", floatX(np.ones_like(self.norm_w)), shape=n_components)
        means = Normal("mu", 0.0, 10.0, shape=n_components)
        precisions = Gamma("tau", 1.0, 1.0, shape=n_components)
        NormalMixture("x_obs", weights, means, tau=precisions, observed=self.norm_x)
        trace = sample(
            5000,
            Metropolis(),
            random_seed=self.random_seed,
            progressbar=False,
            chains=1,
        )

    for var_name, expected in (("w", self.norm_w), ("mu", self.norm_mu)):
        posterior_mean = trace[var_name].mean(axis=0)
        assert_allclose(np.sort(posterior_mean), np.sort(expected), rtol=0.1, atol=0.1)
def test_normal_mixture_nd(self):
    """Compare NormalMixture with the equivalent generic Mixture of Normals.

    Two models share identically named priors (``mus``, ``taus``, ``ws``).
    ``model0`` mixes them with ``NormalMixture``. ``model1`` builds the same
    mixture from a list of ``Normal.dist`` components, one per column. The
    joint logp and the logp of the mixture variable ``m`` must agree at a
    point where ``mus`` is drawn at random.
    """
    nd, ncomp = 3, 5
    grid = (nd, ncomp)

    def add_shared_priors():
        # Must be called inside a Model context; creates mus, taus, ws in order.
        loc = Normal('mus', shape=grid)
        prec = Gamma('taus', alpha=1, beta=1, shape=grid)
        weights = Dirichlet('ws', np.ones(ncomp))
        return loc, prec, weights

    with Model() as model0:
        loc, prec, weights = add_shared_priors()
        mixture0 = NormalMixture('m', w=weights, mu=loc, tau=prec, shape=nd)

    with Model() as model1:
        loc, prec, weights = add_shared_priors()
        per_component = []
        for col in range(ncomp):
            per_component.append(Normal.dist(mu=loc[:, col], tau=prec[:, col]))
        mixture1 = Mixture('m', w=weights, comp_dists=per_component, shape=nd)

    start = model0.test_point
    start['mus'] = np.random.randn(nd, ncomp)
    pairs = [(model0, model1), (mixture0, mixture1)]
    for reference, candidate in pairs:
        assert_allclose(reference.logp(start), candidate.logp(start))
def test_normal_mixture_nd(self, nd, ncomp):
    """Check NormalMixture against an equivalent Mixture of Normals for N-d shapes.

    Three models declare the same priors (``mus``, ``taus``, ``ws``):

    * ``model0`` uses ``NormalMixture`` with an explicit ``comp_shape``;
    * ``model1`` uses the generic ``Mixture`` over one ``Normal.dist`` per
      component, slicing the last axis of the parameter arrays;
    * ``model2`` uses ``NormalMixture`` without ``comp_shape``, which is
      expected to raise ``ValueError`` for some ``nd``/``ncomp`` combinations.

    The logps of the models are compared at a common point built from
    ``test_mus`` and ``test_taus``, the same parameters used to simulate
    ``observed``. So are the logps of the latent mixtures ``m`` and the
    observed mixtures ``obs``.

    Parameters
    ----------
    nd : int or tuple
        Shape of the mixture variable; normalized with ``to_tuple``.
    ncomp : int
        Number of mixture components.
    """
    nd = to_tuple(nd)
    ncomp = int(ncomp)
    comp_shape = nd + (ncomp,)
    test_mus = np.random.randn(*comp_shape)
    test_taus = np.random.gamma(1, 1, size=comp_shape)
    observed = generate_normal_mixture_data(
        w=np.ones(ncomp) / ncomp, mu=test_mus, sd=1 / np.sqrt(test_taus), size=10
    )

    with Model() as model0:
        mus = Normal("mus", shape=comp_shape)
        taus = Gamma("taus", alpha=1, beta=1, shape=comp_shape)
        ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,))
        mixture0 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd, comp_shape=comp_shape)
        obs0 = NormalMixture(
            "obs", w=ws, mu=mus, tau=taus, shape=nd, comp_shape=comp_shape, observed=observed
        )

    with Model() as model1:
        mus = Normal("mus", shape=comp_shape)
        taus = Gamma("taus", alpha=1, beta=1, shape=comp_shape)
        ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,))
        comp_dist = [
            Normal.dist(mu=mus[..., i], tau=taus[..., i], shape=nd) for i in range(ncomp)
        ]
        mixture1 = Mixture("m", w=ws, comp_dists=comp_dist, shape=nd)
        obs1 = Mixture("obs", w=ws, comp_dists=comp_dist, shape=nd, observed=observed)

    with Model() as model2:
        # Expected to fail if comp_shape is not provided, nd is multidimensional
        # and it does not broadcast with ncomp. If by chance it does
        # broadcast, an error is raised if the mixture is given observed data.
        # Furthermore, the Mixture will also raise errors when the observed
        # data is multidimensional but does not broadcast well with comp_dists.
        mus = Normal("mus", shape=comp_shape)
        taus = Gamma("taus", alpha=1, beta=1, shape=comp_shape)
        ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,))
        if len(nd) > 1 and nd[-1] != ncomp:
            with pytest.raises(ValueError):
                NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
            mixture2 = None
        else:
            mixture2 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)

        observed_fails = False
        if len(nd) >= 1 and nd != (1,):
            try:
                np.broadcast(np.empty(comp_shape), observed)
            except ValueError:
                # np.broadcast reports incompatible shapes with ValueError.
                observed_fails = True
        if observed_fails:
            with pytest.raises(ValueError):
                NormalMixture("obs", w=ws, mu=mus, tau=taus, shape=nd, observed=observed)
            obs2 = None
        else:
            obs2 = NormalMixture("obs", w=ws, mu=mus, tau=taus, shape=nd, observed=observed)

    testpoint = model0.test_point
    testpoint["mus"] = test_mus
    # The Gamma prior is sampled on the log scale, so the free variable is
    # "taus_log__". logp() keeps only free variables from the point, so a
    # plain "taus" key would be dropped silently. The assert makes the test
    # fail loudly if the transformed name ever changes.
    taus_key = "taus_log__"
    assert taus_key in testpoint, sorted(testpoint)
    testpoint[taus_key] = np.log(test_taus)

    assert_allclose(model0.logp(testpoint), model1.logp(testpoint))
    assert_allclose(mixture0.logp(testpoint), mixture1.logp(testpoint))
    assert_allclose(obs0.logp(testpoint), obs1.logp(testpoint))
    if mixture2 is not None and obs2 is not None:
        assert_allclose(model0.logp(testpoint), model2.logp(testpoint))
    if mixture2 is not None:
        assert_allclose(mixture0.logp(testpoint), mixture2.logp(testpoint))
    if obs2 is not None:
        assert_allclose(obs0.logp(testpoint), obs2.logp(testpoint))
def test_normal_mixture_nd(self, nd, ncomp):
    """Check NormalMixture against an equivalent Mixture of Normals for N-d shapes.

    ``model0`` (NormalMixture with ``comp_shape``) and ``model1`` (Mixture of
    per-component ``Normal.dist``) must agree everywhere. ``model2``
    (NormalMixture without ``comp_shape``) is expected to raise ``ValueError``
    for some ``nd``/``ncomp`` combinations. Where it does not raise, it must
    agree with ``model0`` too. All comparisons are made at a point built from
    ``test_mus`` and ``test_taus``, the parameters used to simulate
    ``observed``.

    Parameters
    ----------
    nd : int or tuple
        Shape of the mixture variable; normalized with ``to_tuple``.
    ncomp : int
        Number of mixture components.
    """
    nd = to_tuple(nd)
    ncomp = int(ncomp)
    comp_shape = nd + (ncomp,)
    test_mus = np.random.randn(*comp_shape)
    test_taus = np.random.gamma(1, 1, size=comp_shape)
    observed = generate_normal_mixture_data(w=np.ones(ncomp) / ncomp,
                                            mu=test_mus,
                                            sd=1 / np.sqrt(test_taus),
                                            size=10)

    with Model() as model0:
        mus = Normal('mus', shape=comp_shape)
        taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
        ws = Dirichlet('ws', np.ones(ncomp))
        mixture0 = NormalMixture('m', w=ws, mu=mus, tau=taus, shape=nd,
                                 comp_shape=comp_shape)
        obs0 = NormalMixture('obs', w=ws, mu=mus, tau=taus, shape=nd,
                             comp_shape=comp_shape,
                             observed=observed)

    with Model() as model1:
        mus = Normal('mus', shape=comp_shape)
        taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
        ws = Dirichlet('ws', np.ones(ncomp))
        comp_dist = [Normal.dist(mu=mus[..., i], tau=taus[..., i],
                                 shape=nd)
                     for i in range(ncomp)]
        mixture1 = Mixture('m', w=ws, comp_dists=comp_dist, shape=nd)
        obs1 = Mixture('obs', w=ws, comp_dists=comp_dist, shape=nd,
                       observed=observed)

    with Model() as model2:
        # Expected to fail if comp_shape is not provided, nd is
        # multidimensional and it does not broadcast with ncomp. If by chance
        # it does broadcast, an error is raised if the mixture is given
        # observed data.
        # Furthermore, the Mixture will also raise errors when the observed
        # data is multidimensional but does not broadcast well with
        # comp_dists.
        mus = Normal('mus', shape=comp_shape)
        taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
        ws = Dirichlet('ws', np.ones(ncomp))
        if len(nd) > 1 and nd[-1] != ncomp:
            with pytest.raises(ValueError):
                NormalMixture('m', w=ws, mu=mus, tau=taus, shape=nd)
            mixture2 = None
        else:
            mixture2 = NormalMixture('m', w=ws, mu=mus, tau=taus, shape=nd)

        observed_fails = False
        if len(nd) >= 1 and nd != (1,):
            try:
                np.broadcast(np.empty(comp_shape), observed)
            except ValueError:
                # np.broadcast reports incompatible shapes with ValueError.
                observed_fails = True
        if observed_fails:
            with pytest.raises(ValueError):
                NormalMixture('obs', w=ws, mu=mus, tau=taus, shape=nd,
                              observed=observed)
            obs2 = None
        else:
            obs2 = NormalMixture('obs', w=ws, mu=mus, tau=taus, shape=nd,
                                 observed=observed)

    testpoint = model0.test_point
    testpoint['mus'] = test_mus
    # The Gamma prior is sampled on the log scale, so the free variable is
    # 'taus_log__'. logp() keeps only free variables from the point, so a
    # plain 'taus' key would be dropped silently. The assert makes the test
    # fail loudly if the transformed name ever changes.
    taus_key = 'taus_log__'
    assert taus_key in testpoint, sorted(testpoint)
    testpoint[taus_key] = np.log(test_taus)

    assert_allclose(model0.logp(testpoint), model1.logp(testpoint))
    assert_allclose(mixture0.logp(testpoint), mixture1.logp(testpoint))
    assert_allclose(obs0.logp(testpoint), obs1.logp(testpoint))
    if mixture2 is not None and obs2 is not None:
        assert_allclose(model0.logp(testpoint), model2.logp(testpoint))
    if mixture2 is not None:
        assert_allclose(mixture0.logp(testpoint), mixture2.logp(testpoint))
    if obs2 is not None:
        assert_allclose(obs0.logp(testpoint), obs2.logp(testpoint))