Exemplo n.º 1
0
 def time_glm_hierarchical_init(self, init):
     """Benchmark how long ``pm.init_nuts`` takes on the hierarchical GLM model.

     ``init`` is the initialization strategy handed to ``pm.init_nuts``.
     One seed per chain (``0 .. self.chains - 1``) is supplied, and the
     progress bar is disabled.
     """
     with glm_hierarchical_model():
         n_chains = self.chains
         pm.init_nuts(
             init=init,
             chains=n_chains,
             seeds=np.arange(n_chains),
             progressbar=False,
         )
Exemplo n.º 2
0
def check_exec_nuts_init(method):
    """Check that ``pm.init_nuts(init=method)`` returns well-formed start points.

    A small model is built with a length-2 ``Normal`` named "a" and a
    ``HalfNormal`` named "b". ``pm.init_nuts`` is then run twice: once with the
    default chain count and one seed, and once with ``chains=2`` and two seeds.
    Each run must return a list with one entry per chain. The first entry must
    be a dict keyed by both value-variable names.
    """
    with pm.Model() as model:
        pm.Normal("a", mu=0, sigma=1, size=2)
        pm.HalfNormal("b", sigma=1)

    # (extra keyword arguments for init_nuts, expected number of start points)
    cases = (
        ({"seeds": [1]}, 1),
        ({"chains": 2, "seeds": [1, 2]}, 2),
    )
    with model:
        for extra_kwargs, n_expected in cases:
            start, _ = pm.init_nuts(init=method, n_init=10, **extra_kwargs)
            assert isinstance(start, list)
            assert len(start) == n_expected
            first = start[0]
            assert isinstance(first, dict)
            assert model.a.tag.value_var.name in first
            assert model.b.tag.value_var.name in first
Exemplo n.º 3
0
 def track_glm_hierarchical_ess(self, init):
     """Return effective samples of ``mu_a`` per second of sampling time.

     ``pm.init_nuts`` (strategy ``init``, one seed per chain) supplies both
     the start points and the NUTS step. Only the subsequent ``pm.sample``
     call is timed.
     """
     with glm_hierarchical_model():
         start, step = pm.init_nuts(init=init,
                                    chains=self.chains,
                                    progressbar=False,
                                    seeds=np.arange(self.chains))
         # perf_counter is a monotonic, high-resolution clock meant for
         # measuring intervals; time.time() can jump if the system clock
         # is adjusted mid-benchmark, skewing the ESS/second figure.
         t0 = time.perf_counter()
         idata = pm.sample(
             draws=self.draws,
             step=step,
             cores=4,
             chains=self.chains,
             start=start,
             # NOTE(review): `seeds` mirrors the init_nuts keyword above;
             # confirm this PyMC version's pm.sample accepts it (pm.sample
             # has historically named this parameter `random_seed`).
             seeds=np.arange(self.chains),
             progressbar=False,
             compute_convergence_checks=False,
         )
         tot = time.perf_counter() - t0
     # float() requires the ESS of mu_a to be a single value.
     ess = float(az.ess(idata, var_names=["mu_a"])["mu_a"].values)
     return ess / tot
Exemplo n.º 4
0
 def track_marginal_mixture_model_ess(self, init):
     """Return the worst-case effective samples of ``mu`` per second of sampling.

     The NUTS step comes from ``pm.init_nuts`` (strategy ``init``), but the
     start points it returns are discarded. Every chain instead starts from
     its own copy of the start point returned by ``mixture_model()``. Only
     the ``pm.sample`` call is timed.
     """
     model, start = mixture_model()
     with model:
         _, step = pm.init_nuts(init=init,
                                chains=self.chains,
                                progressbar=False,
                                seeds=np.arange(self.chains))
         # One independent shallow copy of the shared start point per chain
         # (assumes `start` is a mapping, as its use of .items() suggested).
         start = [dict(start) for _ in range(self.chains)]
         # perf_counter is a monotonic, high-resolution clock meant for
         # measuring intervals; time.time() can jump if the system clock
         # is adjusted mid-benchmark.
         t0 = time.perf_counter()
         idata = pm.sample(
             draws=self.draws,
             step=step,
             cores=4,
             chains=self.chains,
             start=start,
             # NOTE(review): `seeds` mirrors the init_nuts keyword above;
             # confirm this PyMC version's pm.sample accepts it (pm.sample
             # has historically named this parameter `random_seed`).
             seeds=np.arange(self.chains),
             progressbar=False,
             compute_convergence_checks=False,
         )
         tot = time.perf_counter() - t0
     # Report the smallest ESS across all components of mu (worst case).
     ess = az.ess(idata, var_names=["mu"])["mu"].values.min()
     return ess / tot