def time_glm_hierarchical_init(self, init):
    """Benchmark the NUTS initialization of the hierarchical GLM model.

    Only ``pm.init_nuts`` is executed here (no sampling). ``init`` is
    forwarded as the ``init`` argument, ``self.chains`` sets the number of
    chains, and the seeds are ``0 .. self.chains - 1``.
    """
    with glm_hierarchical_model():
        per_chain_seeds = np.arange(self.chains)
        pm.init_nuts(
            init=init,
            chains=self.chains,
            progressbar=False,
            seeds=per_chain_seeds,
        )
def check_exec_nuts_init(method):
    """Check that ``pm.init_nuts`` runs with ``method`` and returns start points.

    A small model with a 2-element Normal ``a`` and a HalfNormal ``b`` is
    built, then ``pm.init_nuts`` is called once with the default number of
    chains and once with ``chains=2``. In both cases the returned start value
    must be a list with one dict per chain whose first entry contains the
    value-variable names of ``a`` and ``b``.
    """
    with pm.Model() as model:
        pm.Normal("a", mu=0, sigma=1, size=2)
        pm.HalfNormal("b", sigma=1)

    def _assert_start_points(points, n_chains):
        # Shared shape/content checks for the start points of one init call.
        assert isinstance(points, list)
        assert len(points) == n_chains
        assert isinstance(points[0], dict)
        for rv in (model.a, model.b):
            assert rv.tag.value_var.name in points[0]

    with model:
        single_chain, _ = pm.init_nuts(init=method, n_init=10, seeds=[1])
        _assert_start_points(single_chain, 1)

        two_chains, _ = pm.init_nuts(init=method, n_init=10, chains=2, seeds=[1, 2])
        _assert_start_points(two_chains, 2)
def track_glm_hierarchical_ess(self, init):
    """Return the effective sample size of ``mu_a`` per second of sampling.

    NUTS is initialized with ``pm.init_nuts`` (``init`` is forwarded as its
    ``init`` argument), then ``pm.sample`` is run for ``self.draws`` draws on
    ``self.chains`` chains. Only the ``pm.sample`` call is timed; the
    initialization and the ESS computation are excluded from the interval.

    Returns
    -------
    float
        ``az.ess`` of ``mu_a`` divided by the elapsed sampling time in seconds.
    """
    with glm_hierarchical_model():
        start, step = pm.init_nuts(
            init=init,
            chains=self.chains,
            progressbar=False,
            seeds=np.arange(self.chains),
        )
        # perf_counter is monotonic and high resolution, so the measured
        # interval cannot be distorted by system clock adjustments the way
        # a time.time() difference can.
        t0 = time.perf_counter()
        idata = pm.sample(
            draws=self.draws,
            step=step,
            cores=4,
            chains=self.chains,
            start=start,
            seeds=np.arange(self.chains),
            progressbar=False,
            compute_convergence_checks=False,
        )
        tot = time.perf_counter() - t0
    ess = float(az.ess(idata, var_names=["mu_a"])["mu_a"].values)
    return ess / tot
def track_marginal_mixture_model_ess(self, init):
    """Return the worst-case effective sample size of ``mu`` per second.

    The model and a start point come from ``mixture_model()``. The step
    method is built with ``pm.init_nuts`` (``init`` is forwarded as its
    ``init`` argument), but the start values it returns are discarded in
    favour of one copy of the model-provided start point per chain. Only the
    ``pm.sample`` call is timed.

    Returns
    -------
    float
        Minimum of ``az.ess`` over the entries of ``mu`` divided by the
        elapsed sampling time in seconds.
    """
    model, start = mixture_model()
    with model:
        _, step = pm.init_nuts(
            init=init,
            chains=self.chains,
            progressbar=False,
            seeds=np.arange(self.chains),
        )
        # One independent shallow copy of the start point for every chain.
        # NOTE(review): assumes ``start`` is a plain dict -- confirm against
        # ``mixture_model``.
        start = [dict(start) for _ in range(self.chains)]
        # perf_counter is monotonic and high resolution, so the measured
        # interval cannot be distorted by system clock adjustments the way
        # a time.time() difference can.
        t0 = time.perf_counter()
        idata = pm.sample(
            draws=self.draws,
            step=step,
            cores=4,
            chains=self.chains,
            start=start,
            seeds=np.arange(self.chains),
            progressbar=False,
            compute_convergence_checks=False,
        )
        tot = time.perf_counter() - t0
    ess = az.ess(idata, var_names=["mu"])["mu"].values.min()  # worst case
    return ess / tot