def __call__(self, samples: tp.Dict[str, np.ndarray], verbose: bool) -> bool:
    """Decide whether the samples satisfy the ESS and R-hat criteria.

    Args:
        samples: dictionary of sample arrays; converted to inference data
            with ``from_pyjags``.
        verbose: when true, print the two summary statistics.

    Returns:
        True when the smallest effective sample size over
        ``self.variable_names`` is at least ``self.minimum_ess`` and the
        largest ``|rhat - 1|`` is at most ``self.maximum_rhat_deviation``.
    """
    idata = from_pyjags(samples)

    ess = az.ess(idata, var_names=self.variable_names)
    # ``to_dict`` stores a scalar for scalar variables but (nested) lists for
    # array-valued ones, so reduce every entry with NumPy before comparing
    # across variables.
    minimum_ess = min(
        float(np.min(value['data']))
        for value in ess.to_dict()['data_vars'].values())

    rhat = az.rhat(idata, var_names=self.variable_names)
    maximum_rhat_deviation = max(
        float(np.max(np.abs(np.asarray(value['data']) - 1.0)))
        for value in rhat.to_dict()['data_vars'].values())

    if verbose:
        print(f'minimum ess = {minimum_ess}')
        print(f'maximum rhat deviation = {maximum_rhat_deviation}')

    return minimum_ess >= self.minimum_ess and \
        maximum_rhat_deviation <= self.maximum_rhat_deviation
def track_glm_hierarchical_ess(self, init):
    """Return the ESS of ``mu_a`` per second of sampling wall time.

    ``init`` is forwarded to ``pm.init_nuts``. Only the ``pm.sample`` call
    is timed; NUTS initialisation happens before the clock starts.
    """
    with glm_hierarchical_model():
        initial_points, nuts_step = pm.init_nuts(
            init=init,
            chains=self.chains,
            progressbar=False,
            random_seed=123,
        )
        began = time.time()
        trace = pm.sample(
            draws=self.draws,
            step=nuts_step,
            cores=4,
            chains=self.chains,
            start=initial_points,
            random_seed=100,
            progressbar=False,
            compute_convergence_checks=False,
        )
        elapsed = time.time() - began
    mu_a_ess = az.ess(trace, var_names=["mu_a"])["mu_a"].values
    return float(mu_a_ess) / elapsed
def track_marginal_mixture_model_ess(self, init):
    """Return the worst-case ESS of ``mu`` per second of sampling wall time.

    ``init`` is forwarded to ``pm.init_nuts``. The start point returned by
    ``mixture_model()`` is copied once per chain, and only the ``pm.sample``
    call is timed.
    """
    model, start = mixture_model()
    with model:
        _, nuts_step = pm.init_nuts(
            init=init,
            chains=self.chains,
            progressbar=False,
            random_seed=123,
        )
        # every chain gets its own shallow copy of the start point
        per_chain_start = [dict(start) for _ in range(self.chains)]
        began = time.time()
        trace = pm.sample(
            draws=self.draws,
            step=nuts_step,
            cores=4,
            chains=self.chains,
            start=per_chain_start,
            random_seed=100,
            progressbar=False,
            compute_convergence_checks=False,
        )
        elapsed = time.time() - began
    worst_ess = az.ess(trace, var_names=["mu"])["mu"].values.min()  # worst case
    return worst_ess / elapsed
def run(steppers, p):
    """Sample a correlated 2-D normal with each step method and compare ESS.

    Args:
        steppers: iterable of step-method classes; duplicates are removed.
        p: off-diagonal entry of the 2x2 covariance ``[[1, p], [p, 1]]``.

    Returns:
        ``(traces, effn, runtimes)``, three dicts keyed by step-class name:
        the trace, the mean ESS divided by draws and by chains, and the
        wall time of the ``pm.sample`` call.
    """
    steppers = set(steppers)
    traces, effn, runtimes = {}, {}, {}

    # Both model variants share the same target distribution.
    mean = np.array([0.0, 0.0])
    covariance = np.array([[1.0, p], [p, 1.0]])

    with pm.Model():
        if USE_XY:
            x = pm.Flat("x")
            y = pm.Flat("y")
            log_density = pm.MvNormal.dist(
                mu=mean, cov=covariance, shape=(2, )
            ).logp(tt.stack([x, y]))
            pm.Potential("logp_xy", log_density)
            start = {"x": 0, "y": 0}
        else:
            pm.MvNormal("z", mu=mean, cov=covariance, shape=(2, ))
            start = {"z": [0, 0]}

        for step_cls in steppers:
            name = step_cls.__name__
            began = time.time()
            mt = pm.sample(draws=10000, chains=16, step=step_cls(), start=start)
            runtimes[name] = time.time() - began
            total_draws = len(mt) * mt.nchains
            print("{} samples across {} chains".format(total_draws, mt.nchains))
            traces[name] = mt
            en = az.ess(mt)
            print(f"effective: {en}\r\n")
            ess_key = "x" if USE_XY else "z"
            effn[name] = np.mean(en[ess_key]) / len(mt) / mt.nchains
    return traces, effn, runtimes
def loss_average_plot():
    """Plot effective sample size as a function of n_particles.

    Reads ``experiments/loss_averages/out.p``, computes ``az.ess`` for every
    run of every method, and writes the figure to ``loss_average.png``.
    """
    # Load experiments
    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # on experiment output that you produced yourself.
    with open("experiments/loss_averages/out.p", "rb") as handle:
        out = pickle.load(handle)
    samples = out['samples']
    particle_schedule = out['particle_schedule']

    # ESS: one list per method, one entry per run. Sized from the data
    # instead of a hard-coded 3.
    ess = [[az.ess(run) for run in method_runs] for method_runs in samples]

    # Paper plot
    sns.set_context('talk')
    fig, ax = plt.subplots(1, 1)
    labels = ['ST-ABC', 'K2-ABC', 'W-ABC']
    line_cols = ['seagreen', 'royalblue', 'darkorange']
    for i, method_ess in enumerate(ess):
        ax = sns.lineplot(x=particle_schedule,
                          y=method_ess,
                          ax=ax,
                          linewidth=4,
                          linestyle='--',
                          label=labels[i],
                          color=line_cols[i])
    ax.autoscale(enable=True, axis='x', tight=True)
    ax.legend(frameon=False, handlelength=1.8)
    ax.set(xlabel='Number of particles')
    ax.set_ylabel('Effective sample size')
    ax.set(yscale="log")
    fig.set_size_inches(7, 7)
    fig.savefig('loss_average.png', bbox_inches='tight')
def test_neff(self):
    """Assert every variable's ESS is strictly above ``self.min_n_eff``.

    Does nothing when the instance has no ``min_n_eff`` attribute. The
    first ``self.burn`` entries of the trace are dropped before the ESS
    is computed.
    """
    if not hasattr(self, "min_n_eff"):
        return
    ess_by_var = az.ess(self.trace[self.burn:])
    for name in ess_by_var:
        npt.assert_array_less(self.min_n_eff, ess_by_var[name])
def normal_proposal(old_point):
    """Symmetric Gaussian proposal centred on ``old_point``.

    Draws one sample from ``Normal(old_point, sigma)`` with the same scale
    for every component. ``sigma`` is not defined in this function; it is
    read from the enclosing scope.
    """
    return Normal(old_point, sigma*torch.ones_like(old_point)).sample()


tf = HarmonicTrialFunction(torch.ones(1))
n_walkers = 2
init_config = torch.ones(n_walkers, 1)
results = metropolis_symmetric(tf, init_config, normal_proposal, num_walkers=n_walkers, num_steps=100000)
dataset1 = az.convert_to_dataset(results.numpy())
dataset2 = az.convert_to_inference_data(results.numpy())

# One ESS diagnostic plot per kind, each saved right after it is drawn.
az.plot_ess(dataset2, kind="local")
plt.savefig("Local")
az.plot_ess(dataset2, kind="quantile")
plt.savefig("quantile")
az.plot_ess(dataset2, kind="evolution")
plt.savefig("Evolution")
print(az.ess(dataset1).data_vars)

# In the Output_of_run array we are using units of 1000.
#Output_of_run = numpy.array([0.02366, 1.087, 3.579, 7.21, 11.32, 15.9, 20.19, 25.2, 29.98, 32.94, 36.67, 39.41, 38.68, 42.96, 44.4, 45.35, 44.83, 45.94, 43.73, 46.34, 44.69, 45.15, 41.88,41.41, 41.33, 41, 38.46, 38.3, 37.49, 36.02])
#y_data = Output_of_run
#x_data = numpy.linspace(0.01, 3, 30)
#plt.scatter(x_data, y_data, c='r', label='ess scatter')
##plt.plot(x_data, y_data, label='ess fit')
#plt.xlabel('x')
#plt.ylabel('y')
#plt.title('ess vs sigma ')
#plt.legend()
#plt.show()
#plt.savefig('ess vs sigma')
def sample(
    self,
    p0,
    Niter,
    ladder=None,
    Tmin=1,
    Tmax=None,
    Tskip=100,
    isave=1000,
    covUpdate=1000,
    weights=None,
    initialize_jump_proposal_kwargs=None,
    burn=10000,
    maxIter=None,
    thin=1,
    i0=0,
    neff=100000,
    write_cold_chains=False,
    writeHotChains=False,
    save_jump_stats=False,
    hotChain=False,
    n_cold_chains=2,
):
    """
    Function to carry out PTMCMC sampling.

    @param p0: Initial parameter vector
    @param Niter: Number of iterations to use for each cold (T = 1) chain
    @param ladder: User defined temperature ladder
    @param Tmin: Minimum temperature in ladder (default=1)
    @param Tmax: Maximum temperature in ladder (default=None)
    @param Tskip: Number of steps between proposed temperature swaps (default=100)
    @param isave: Number of iterations before writing to file (default=1000)
    @param covUpdate: Number of iterations between AM covariance updates (default=1000)
    @param weights: Dict of jump weights, forwarded to ``initialize`` and stored
        on ``self.weights``. ``None`` (the default) means an empty dict.
        Presumably keyed by jump type (SCAM, AM, DE, NUTS, MALA, HMC) -- confirm
        against ``initialize``.
    @param initialize_jump_proposal_kwargs: Dict forwarded to ``initialize``.
        ``None`` (the default) means an empty dict.
    @param burn: Burn in time (DE jumps added after this iteration) (default=10000)
    @param maxIter: Maximum number of iterations. Defaults to ``Niter`` on MPI
        rank 0 and ``2 * Niter`` on the other ranks.
    @param thin: Save every thin MCMC samples
    @param i0: Iteration to start MCMC (if i0 != 0, do not re-initialize)
    @param neff: Number of effective samples to collect before terminating
    @param write_cold_chains: Forwarded to ``initialize``
    @param writeHotChains: Forwarded to ``initialize``
    @param save_jump_stats: Not referenced in this method
    @param hotChain: Forwarded to ``initialize``
    @param n_cold_chains: Number of cold chains that are run one after another

    @return: ``Result(self._chain, self._lnlike, self._lnprob, self.burn,
        self.n_cold_chains)``
    """
    # Fresh dicts per call: ``weights`` is stored on ``self`` below, so a
    # shared mutable default would leak state between samplers.
    if weights is None:
        weights = {}
    if initialize_jump_proposal_kwargs is None:
        initialize_jump_proposal_kwargs = {}

    # get maximum number of iteration
    if maxIter is None and self.MPIrank > 0:
        maxIter = 2 * Niter
    elif maxIter is None and self.MPIrank == 0:
        maxIter = Niter

    # set up arrays to store lnprob, lnlike and chain
    # NOTE(review): N is not used in the rest of this method.
    N = int(maxIter / thin)

    # if picking up from previous run, don't re-initialize
    if i0 == 0:
        self.initialize(
            Niter,
            ladder=ladder,
            Tmin=Tmin,
            Tmax=Tmax,
            Tskip=Tskip,
            isave=isave,
            covUpdate=covUpdate,
            weights=weights,
            initialize_jump_proposal_kwargs=initialize_jump_proposal_kwargs,
            burn=burn,
            maxIter=maxIter,
            thin=thin,
            i0=i0,
            neff=neff,
            writeHotChains=writeHotChains,
            write_cold_chains=write_cold_chains,
            hotChain=hotChain,
            n_cold_chains=n_cold_chains,
        )

    # NOTE(review): assumed to run on every call, not only when i0 == 0 --
    # confirm against the original layout.
    self.jump_proposal_kwargs = {}
    self.weights = weights

    ### compute lnprob for initial point in chain ###

    # if resuming, just start with first point in chain
    if self.resume and self.resumeLength > 0:
        p0, lnlike0, lnprob0 = (
            self.resumechain[0, :-4],
            self.resumechain[0, -3],
            self.resumechain[0, -4],
        )
    else:
        # compute prior
        lp = self.logp(p0)

        if lp == float(-np.inf):
            lnprob0 = -np.inf
            lnlike0 = -np.inf
        else:
            lnlike0 = self.logl(p0)
            lnprob0 = 1 / self.temp * lnlike0 + lp

    # record first values
    self.updateChains(p0, lnlike0, lnprob0, i0, 0)

    self.comm.barrier()

    # start iterations
    self.tstart = time.time()
    runComplete = False
    for i in range(n_cold_chains):
        print("chain %s" % i)
        # Reset per chain: a stale value from the previous chain would stop
        # this chain after its first step.
        Neff = 0
        iteration = i0
        for _ in range(Niter - 1):
            iteration += 1

            # call PTMCMCOneStep
            p0, lnlike0, lnprob0 = self.PTMCMCOneStep(
                p0, lnlike0, lnprob0, iteration, i)

            # compute effective number of samples
            if (iteration % 10000 == 0 and iteration > 2 * self.burn
                    and self.MPIrank == 0):
                try:
                    ### this will calculate the number of effective
                    ### samples for each chain
                    samples = np.expand_dims(
                        self._chain[i, :iteration - 1], axis=0)
                    arviz_samples = az.convert_to_inference_data(samples)
                    Neff = int(
                        np.min(az.ess(arviz_samples).to_array().values))
                    print("\n {0} total samples".format(iteration))
                    print("\n {0} effective samples".format(Neff))
                except NameError:
                    Neff = 0

            # stop if reached effective number of samples
            if self.MPIrank == 0 and int(Neff) > self.neff:
                if self.verbose:
                    print(
                        "\nRun Complete with {0} effective samples".format(
                            int(Neff)))
                break

            # NOTE(review): runComplete is never set to True on rank 0 in this
            # method, so this notification branch does not fire as written.
            if self.MPIrank == 0 and runComplete:
                for jj in range(1, self.nchain):
                    self.comm.send(runComplete, dest=jj, tag=55)

            # check for other chains
            if self.MPIrank > 0:
                runComplete = self.comm.Iprobe(source=0, tag=55)
                time.sleep(0.000001)  # trick to get around

    return Result(self._chain, self._lnlike, self._lnprob, self.burn,
                  self.n_cold_chains)
def _run_convergence_checks(self, idata: arviz.InferenceData, model):
    """Compute ESS and R-hat for the free variables and record warnings.

    Stores the diagnostics on ``self._ess`` and ``self._rhat`` and passes the
    resulting ``SamplerWarning`` objects to ``self._add_warnings``. Returns
    early with a single informational warning when there is no posterior
    group or only one chain.

    R-hat is reported at the first threshold exceeded (1.4 error, 1.2 warn,
    1.05 info). ESS is reported as an error below 200 effective samples when
    at least 500 draws exist, otherwise as warn/info below 10% / 25% of the
    total number of draws.
    """
    if not hasattr(idata, 'posterior'):
        self._add_warnings([
            SamplerWarning(
                WarningType.BAD_PARAMS,
                "No posterior samples. Unable to run convergence checks",
                'info', None, None, None)
        ])
        return

    if idata.posterior.sizes['chain'] == 1:
        self._add_warnings([
            SamplerWarning(
                WarningType.BAD_PARAMS,
                ("Only one chain was sampled, this makes it impossible to "
                 "run some convergence checks"),
                'info')
        ])
        return

    # Prefer the untransformed name of a transformed variable when the model
    # exposes it, and keep only names that are present in the posterior.
    known_names = [rv.name for rv in model.free_RVs + model.deterministics]
    varnames = []
    for rv in model.free_RVs:
        candidate = rv.name
        if is_transformed_name(candidate):
            untransformed = get_untransformed_name(candidate)
            if untransformed in known_names:
                candidate = untransformed
        if candidate in idata.posterior:
            varnames.append(candidate)

    self._ess = ess = arviz.ess(idata, var_names=varnames)
    self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

    warnings = []

    # Thresholds from most to least severe; only the first match is reported.
    rhat_levels = (
        (1.4,
         ("The rhat statistic is larger than 1.4 for some "
          "parameters. The sampler did not converge."),
         'error'),
        (1.2,
         ("The rhat statistic is larger than 1.2 for some "
          "parameters."),
         'warn'),
        (1.05,
         ("The rhat statistic is larger than 1.05 for some "
          "parameters. This indicates slight problems during "
          "sampling."),
         'info'),
    )
    rhat_max = max(val.max() for val in rhat.values())
    for threshold, text, level in rhat_levels:
        if rhat_max > threshold:
            warnings.append(
                SamplerWarning(WarningType.CONVERGENCE, text, level,
                               extra=rhat))
            break

    eff_min = min(val.min() for val in ess.values())
    sizes = idata.posterior.sizes
    n_samples = sizes['chain'] * sizes['draw']

    ess_issue = None
    if eff_min < 200 and n_samples >= 500:
        ess_issue = (
            ("The estimated number of effective samples is smaller than "
             "200 for some parameters."),
            'error')
    elif eff_min / n_samples < 0.1:
        ess_issue = (
            ("The number of effective samples is smaller than "
             "10% for some parameters."),
            'warn')
    elif eff_min / n_samples < 0.25:
        ess_issue = (
            ("The number of effective samples is smaller than "
             "25% for some parameters."),
            'info')
    if ess_issue is not None:
        text, level = ess_issue
        warnings.append(
            SamplerWarning(WarningType.CONVERGENCE, text, level, extra=ess))

    self._add_warnings(warnings)
def _run_convergence_checks(self, idata: arviz.InferenceData, model):
    """Compute ESS and R-hat for the free variables and record warnings.

    Returns early with one informational warning when there is no posterior
    group, when only one chain was sampled, or when fewer than 4 chains were
    sampled. Otherwise stores the diagnostics on ``self._ess`` and
    ``self._rhat``, flags R-hat above 1.01 and a per-chain ESS below 100,
    and passes the warnings to ``self._add_warnings``.
    """

    def bail_out(text, *extra_args):
        # Record a single informational BAD_PARAMS warning.
        self._add_warnings([
            SamplerWarning(WarningType.BAD_PARAMS, text, "info", *extra_args)
        ])

    if not hasattr(idata, "posterior"):
        bail_out("No posterior samples. Unable to run convergence checks",
                 None, None, None)
        return

    if idata.posterior.sizes["chain"] == 1:
        bail_out("Only one chain was sampled, this makes it impossible to "
                 "run some convergence checks")
        return

    if idata.posterior.sizes["chain"] < 4:
        bail_out(
            "We recommend running at least 4 chains for robust computation of "
            "convergence diagnostics")
        return

    valid_name = [rv.name for rv in model.free_RVs + model.deterministics]

    def reported_name(name):
        # Use the untransformed name when the model exposes it.
        if not is_transformed_name(name):
            return name
        plain = get_untransformed_name(name)
        return plain if plain in valid_name else name

    varnames = []
    for rv in model.free_RVs:
        name = reported_name(rv.name)
        if name in idata.posterior:
            varnames.append(name)

    self._ess = ess = arviz.ess(idata, var_names=varnames)
    self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

    warnings = []

    worst_rhat = max(val.max() for val in rhat.values())
    if worst_rhat > 1.01:
        warnings.append(
            SamplerWarning(
                WarningType.CONVERGENCE,
                ("The rhat statistic is larger than 1.01 for some "
                 "parameters. This indicates problems during sampling. "
                 "See https://arxiv.org/abs/1903.08008 for details"),
                "info",
                extra=rhat))

    smallest_ess = min(val.min() for val in ess.values())
    if smallest_ess / idata.posterior.sizes["chain"] < 100:
        warnings.append(
            SamplerWarning(
                WarningType.CONVERGENCE,
                ("The effective sample size per chain is smaller than 100 for some parameters. "
                 " A higher number is needed for reliable rhat and ess computation. "
                 "See https://arxiv.org/abs/1903.08008 for details"),
                "error",
                extra=ess))

    self._add_warnings(warnings)
import torch
import matplotlib.pyplot as plt
from qmc.mcmc import metropolis_symmetric, clip_mvnormal_proposal, normal_proposal
from qmc.wavefunction import HarmonicTrialFunction
import arviz as az

# First we sample from a 1D scalar field.
# We use a simple Gaussian with one parameter.
# In fact, we just use the harmonic oscillator ansatz.
# We also compute the effective sample size using az.ess() from the arviz package.
tf = HarmonicTrialFunction(torch.ones(1))
n_walkers = 2
init_config = torch.ones(n_walkers, 1)
# Run the symmetric Metropolis sampler with n_walkers walkers for 10000 steps.
results = metropolis_symmetric(tf, init_config, normal_proposal, num_walkers=n_walkers, num_steps=10000)
# Convert the sampler output to an ArviZ dataset and print its ESS.
dataset = az.convert_to_dataset(results.numpy())
print(az.ess(dataset))
# Load the posterior summary records and index them by their "variable" field.
# NOTE(review): `f` and `env_name` are defined outside this excerpt.
res_posterior_summary = json.load(f)
res_posterior_summary = pd.DataFrame.from_records(res_posterior_summary, index="variable")
res_posterior_summary.index.name = None
print(res_posterior_summary)
# Reference values for this environment, read with high float precision and
# cast to float after the index is moved back into a column.
reference = (pd.read_csv(f"./reference_posterior_{env_name}.csv", index_col=0, float_precision="high").reset_index().astype(float))
# test arviz functions
# Each entry maps a diagnostic label to a one-argument callable that computes
# that diagnostic with ArviZ.
funcs = {
    "rhat_rank": lambda x: az.rhat(x, method="rank"),
    "rhat_raw": lambda x: az.rhat(x, method="identity"),
    "ess_bulk": lambda x: az.ess(x, method="bulk"),
    "ess_tail": lambda x: az.ess(x, method="tail"),
    "ess_mean": lambda x: az.ess(x, method="mean"),
    "ess_sd": lambda x: az.ess(x, method="sd"),
    "ess_median": lambda x: az.ess(x, method="median"),
    "ess_raw": lambda x: az.ess(x, method="identity"),
    "ess_quantile01": lambda x: az.ess(x, method="quantile", prob=0.01),
    "ess_quantile10": lambda x: az.ess(x, method="quantile", prob=0.1),
    "ess_quantile30": lambda x: az.ess(x, method="quantile", prob=0.3),
    "mcse_mean": lambda x: az.mcse(x, method="mean"),
    "mcse_sd": lambda x: az.mcse(x, method="sd"),
    "mcse_median": lambda x: az.mcse(x, method="quantile", prob=0.5),
    "mcse_quantile01": lambda x: az.mcse(x, method="quantile", prob=0.01),
    "mcse_quantile10": lambda x: az.mcse(x, method="quantile", prob=0.1),
    "mcse_quantile30": lambda x: az.mcse(x, method="quantile", prob=0.3),
}
def compute_ess(chains):
    """Return ``arviz.ess`` of ``chains`` after conversion to a NumPy array.

    ``chains`` is moved to the CPU and detached before conversion.
    """
    # shape N_walkers, N_samples
    host_array = chains.cpu().detach().numpy()
    return arviz.ess(host_array)
def collect_samples_and_stats(
    config: SimpleNamespace,
    model_cls: Type[BaseModel],
    all_ppl_details: List[PPLDetails],
    train_data: xr.Dataset,
    test_data: xr.Dataset,
    output_dir: str,
) -> Tuple[xr.Dataset, xr.Dataset]:
    """
    Run inference for every PPL and collect samples, diagnostics and timings.

    For each PPL the inference object is compiled once and then run
    `config.trials` times; each trial is treated as one chain when the
    effective sample size and Rhat are computed.

    :param config: The benchmark configuration. This function reads
        `iterations`, `trials`, `num_warmup` and (optionally) `save_samples`.
    :param model_cls: The model class
    :param all_ppl_details: For each ppl the impl and inference classes etc.
    :param train_data: The training dataset.
    :param test_data: The held-out test dataset.
    :param output_dir: The directory for storing results.
    :returns: Two datasets:
        variable_metrics
            Coordinates: ppl, metric (n_eff, Rhat), others from model
            Data variables: from model
        other_metrics
            Coordinates: ppl, chain, draw, phase (compile, infer),
                percentile (min, median, max)
            Data variables: pll (ppl, chain, draw), timing (ppl, chain, phase),
                overall_neff (ppl, percentile),
                overall_neff_per_time (ppl, percentile)
    :raises ValueError: if num_warmup exceeds `config.iterations` for an
        adaptive PPL, or a non-adaptive PPL is given a nonzero num_warmup.
    :raises RuntimeError: if a PPL returns a number of draws different from
        `config.iterations`.
    """
    all_variable_metrics, all_pll, all_timing, all_names = [], [], [], []
    all_samples, all_overall_neff, all_overall_neff_per_time = [], [], []
    for pplobj in all_ppl_details:
        all_names.append(pplobj.name)
        # one RNG per PPL; the compile seed and every trial seed are drawn
        # from it in order
        rand = np.random.RandomState(pplobj.seed)
        LOGGER.info(f"Starting inference on `{pplobj.name}` with seed {pplobj.seed}")
        # first compile the PPL Implementation this involves two steps
        compile_t1 = time.time()
        # compile step 1: instantiate ppl inference object
        infer_obj = pplobj.inference_class(pplobj.impl_class, train_data.attrs)
        # compile step 2: call compile
        infer_obj.compile(seed=rand.randint(1, int(1e7)), **pplobj.compile_args)
        compile_time = time.time() - compile_t1
        LOGGER.info(f"compiling on `{pplobj.name}` took {compile_time:.2f} secs")
        # the per-PPL num_warmup overrides the config value when it is set
        if infer_obj.is_adaptive:
            num_warmup = (
                config.num_warmup if pplobj.num_warmup is None else pplobj.num_warmup
            )
            if num_warmup > config.iterations:
                raise ValueError(
                    f"num_warmup ({num_warmup}) should be less than iterations "
                    f"({config.iterations})"
                )
        else:
            if pplobj.num_warmup:
                raise ValueError(
                    f"{pplobj.name} is not adaptive and does not accept a nonzero "
                    "num_warmup as its parameter."
                )
            else:
                num_warmup = 0
        # then run inference for each trial
        trial_samples, trial_pll, trial_timing = [], [], []
        for trialnum in range(config.trials):
            infer_t1 = time.time()
            samples = infer_obj.infer(
                data=train_data,
                iterations=config.iterations,
                num_warmup=num_warmup,
                seed=rand.randint(1, int(1e7)),
                **pplobj.infer_args,
            )
            infer_time = time.time() - infer_t1
            LOGGER.info(f"inference trial {trialnum} took {infer_time:.2f} secs")
            # Drop all NaN samples returned by PPL, which could happen when a PPL
            # needs to fill in missing warmup samples (e.g. Jags)
            valid_samples = samples.dropna("draw")
            # the check is on the raw sample count, NaN draws included
            if samples.sizes["draw"] != config.iterations:
                raise RuntimeError(
                    f"Expected {config.iterations} samples, but {samples.sizes['draw']}"
                    f" samples are returned by {pplobj.name}"
                )
            # compute the pll per sample and then convert it to the actual pll over
            # cumulative samples
            persample_pll = model_cls.evaluate_posterior_predictive(
                valid_samples, test_data
            )
            pll = np.logaddexp.accumulate(persample_pll) - np.log(
                np.arange(valid_samples.sizes["draw"]) + 1
            )
            LOGGER.info(f"PLL = {str(pll)}")
            trial_samples.append(samples)
            # After dropping the NaN samples, the length of PLL could be shorter than
            # iterations. So we'll need to pad the PLL list to make its length
            # consistent across different PPLs
            padded_pll = np.full(config.iterations, np.nan)
            padded_pll[valid_samples.draw.data] = pll
            trial_pll.append(padded_pll)
            trial_timing.append([compile_time, infer_time])
            # finally, give the inference object an opportunity
            # to write additional diagnostics
            infer_obj.additional_diagnostics(output_dir, f"{pplobj.name}_{trialnum}")
        del infer_obj
        # concatenate the samples data from each trial together so we can compute metrics
        trial_samples_data = xr.concat(
            trial_samples, pd.Index(data=np.arange(config.trials), name="chain")
        )
        # exclude warm up samples when calculating diagnostics
        trial_samples_no_warmup = trial_samples_data.isel(draw=slice(num_warmup, None))
        neff_data = arviz.ess(trial_samples_no_warmup)
        rhat_data = arviz.rhat(trial_samples_no_warmup)
        LOGGER.info(f"Trials completed for {pplobj.name}")
        LOGGER.info("== n_eff ===")
        LOGGER.info(str(neff_data.data_vars))
        LOGGER.info("== Rhat ===")
        LOGGER.info(str(rhat_data.data_vars))
        # compute ess/time
        neff_df = neff_data.to_dataframe()
        overall_neff = [
            neff_df.values.min(),
            np.median(neff_df.values),
            neff_df.values.max(),
        ]
        # column 1 of each timing row is the inference time
        mean_inference_time = np.mean(np.array(trial_timing)[:, 1])
        overall_neff_per_time = np.array(overall_neff) / mean_inference_time
        LOGGER.info("== overall n_eff [min, median, max]===")
        LOGGER.info(str(overall_neff))
        LOGGER.info("== overall n_eff/s [min, median, max]===")
        LOGGER.info(str(overall_neff_per_time))
        trial_variable_metrics_data = xr.concat(
            [neff_data, rhat_data], pd.Index(data=["n_eff", "Rhat"], name="metric")
        )
        all_variable_metrics.append(trial_variable_metrics_data)
        all_pll.append(trial_pll)
        all_timing.append(trial_timing)
        all_samples.append(trial_samples_data)
        all_overall_neff.append(overall_neff)
        all_overall_neff_per_time.append(overall_neff_per_time)
    # merge the trial-level metrics at the PPL level
    all_variable_metrics_data = xr.concat(
        all_variable_metrics, pd.Index(data=all_names, name="ppl")
    )
    all_other_metrics_data = xr.Dataset(
        {
            "timing": (["ppl", "chain", "phase"], all_timing),
            "pll": (["ppl", "chain", "draw"], all_pll),
            "overall_neff": (["ppl", "percentile"], all_overall_neff),
            "overall_neff_per_time": (["ppl", "percentile"], all_overall_neff_per_time),
        },
        coords={
            "ppl": np.array(all_names),
            "chain": np.arange(config.trials),
            "phase": np.array(["compile", "infer"]),
            "draw": np.arange(config.iterations),
            "percentile": np.array(["min", "median", "max"]),
        },
    )
    all_samples_data = xr.concat(all_samples, pd.Index(data=all_names, name="ppl"))
    model_cls.additional_metrics(output_dir, all_samples_data, train_data, test_data)
    LOGGER.info("all benchmark samples and metrics collected")
    # save the samples data only if requested
    if getattr(config, "save_samples", False):
        save_dataset(output_dir, "samples", all_samples_data)
    # write out the metrics
    save_dataset(output_dir, "diagnostics", all_variable_metrics_data)
    save_dataset(output_dir, "metrics", all_other_metrics_data)
    return all_variable_metrics_data, all_other_metrics_data
To help us understand these diagnostics we are going to create two _synthetic posteriors_. The first one is a sample from a uniform distribution. We generate it using SciPy and we call it `good_chains`. This is an example of a "good" sample because we are generating independent and identically distributed (iid) samples, and ideally this is what we want to approximate the posterior. The second one is called `bad_chains`, and it will represent a poor sample from the posterior. `bad_chains` is a poor _sample_ for two reasons: * Values are not independent. On the contrary, they are highly correlated, meaning that given any number at any position in the sequence we can compute the exact sequence of numbers both before and after the given number. High correlation is the opposite of independence. * Values are not identically distributed; as you will see, we are creating an array of 2 rows, the first one with numbers from 0 to 0.5 and the second one from 0.5 to 1. good_chains = stats.uniform.rvs(0, 1,size=(2,500)) bad_chains = np.linspace(0, 1, 1000).reshape(2, -1) ## Effective Sample Size (ESS) When using sampling methods like MCMC it is common to wonder if a particular sample is large enough to confidently compute what we want, for example a parameter mean. Answering in terms of the number of samples is generally not a good idea, as samples from MCMC methods will be autocorrelated and autocorrelation decreases the actual amount of information contained in a sample. Instead, a better idea is to estimate the **effective sample size**: this is the number of samples we would have if our sample were actually iid. Using ArviZ we can compute it using `az.ess(⋅)` az.ess(good_chains), az.ess(bad_chains) This is telling us that even though in both cases we have 1000 samples, `bad_chains` is somewhat equivalent to an iid sample of size $\approx 2$, while `good_chains` is $\approx 1000$. 
If you resample `good_chains` you will see that the effective sample size you get will be different for each sample. This is expected as the samples will not be exactly the same; they are, after all, samples. Nevertheless, on average, the value of effective sample size will be lower than the $N$ number of samples. Notice, however, that ESS could in fact be larger! When using the NUTS sampler, values of $ESS > N$ can happen for parameters whose posterior distributions are close to Gaussian and which are almost independent of other parameters. > As a general rule of thumb we recommend an `ess` greater than 50 per chain, otherwise the estimation of the `ess` itself and the estimation of $\hat R$ are most likely unreliable. Because MCMC methods can have difficulties with mixing, it is important to use between-chain information in computing the ESS. This is one reason to routinely run more than one chain when fitting a Bayesian model using MCMC methods. We can also compute the effective sample size using `az.summary(⋅)` az.summary(good_chains) As you can see, `az.summary(⋅)` provides 4 values for `ESS`: mean, sd, bulk and tail. Furthermore, if you check the `method` argument of the `az.ess(⋅)` function you will see the following options - "bulk" - "tail"
def collect_samples_and_stats(
    config: SimpleNamespace,
    model_cls: Type[BaseModel],
    all_ppl_details: List[PPLDetails],
    train_data: xr.Dataset,
    test_data: xr.Dataset,
    output_dir: str,
) -> Tuple[xr.Dataset, xr.Dataset]:
    """
    Run inference for every PPL and collect samples, diagnostics and timings.

    For each PPL the inference object is compiled once and then run
    `config.trials` times; each trial is treated as one chain when the
    effective sample size and Rhat are computed.

    :param config: The benchmark configuration. This function reads
        `num_samples`, `trials` and (optionally) `save_samples`.
    :param model_cls: The model class
    :param all_ppl_details: For each ppl the impl and inference classes etc.
    :param train_data: The training dataset.
    :param test_data: The held-out test dataset.
    :param output_dir: The directory for storing results.
    :returns: Two datasets:
        variable_metrics
            Coordinates: ppl, metric (n_eff, Rhat), others from model
            Data variables: from model
        other_metrics
            Coordinates: ppl, chain, draw, phase (compile, infer),
                percentile (min, median, max)
            Data variables: pll (ppl, chain, draw), timing (ppl, chain, phase),
                overall_neff (ppl, percentile),
                overall_neff_per_time (ppl, percentile)
    """
    all_variable_metrics, all_pll, all_timing, all_names = [], [], [], []
    all_samples, all_overall_neff, all_overall_neff_per_time = [], [], []
    # randint is documented for integer bounds; 1e7 is a float literal
    max_seed = int(1e7)
    for pplobj in all_ppl_details:
        all_names.append(pplobj.name)
        # one RNG per PPL; the compile seed and every trial seed are drawn
        # from it in order
        rand = np.random.RandomState(pplobj.seed)
        LOGGER.info(f"Starting inference on `{pplobj.name}` with seed {pplobj.seed}")
        # first compile the PPL Implementation this involves two steps
        compile_t1 = time.time()
        # compile step 1: instantiate ppl inference object
        infer_obj = pplobj.inference_class(pplobj.impl_class, train_data.attrs)
        # compile step 2: call compile
        infer_obj.compile(seed=rand.randint(1, max_seed), **pplobj.compile_args)
        compile_time = time.time() - compile_t1
        LOGGER.info(f"compiling on `{pplobj.name}` took {compile_time:.2f} secs")
        # then run inference for each trial
        trial_samples, trial_pll, trial_timing = [], [], []
        for trialnum in range(config.trials):
            infer_t1 = time.time()
            samples = infer_obj.infer(
                data=train_data,
                num_samples=config.num_samples,
                seed=rand.randint(1, max_seed),
                **pplobj.infer_args,
            )
            infer_time = time.time() - infer_t1
            LOGGER.info(f"inference trial {trialnum} took {infer_time:.2f} secs")
            # compute the pll per sample and then convert it to the actual pll over
            # cumulative samples
            persample_pll = model_cls.evaluate_posterior_predictive(samples, test_data)
            pll = np.logaddexp.accumulate(persample_pll) - np.log(
                np.arange(config.num_samples) + 1
            )
            LOGGER.info(f"PLL = {str(pll)}")
            trial_samples.append(samples)
            trial_pll.append(pll)
            trial_timing.append([compile_time, infer_time])
            # finally, give the inference object an opportunity
            # to write additional diagnostics
            infer_obj.additional_diagnostics(output_dir, f"{pplobj.name}_{trialnum}")
        del infer_obj
        # concatenate the samples data from each trial together so we can compute metrics
        trial_samples_data = xr.concat(
            trial_samples, pd.Index(data=np.arange(config.trials), name="chain")
        )
        neff_data = arviz.ess(trial_samples_data)
        rhat_data = arviz.rhat(trial_samples_data)
        LOGGER.info(f"Trials completed for {pplobj.name}")
        LOGGER.info("== n_eff ===")
        LOGGER.info(str(neff_data.data_vars))
        LOGGER.info("== Rhat ===")
        LOGGER.info(str(rhat_data.data_vars))
        # compute ess/time
        neff_df = neff_data.to_dataframe()
        overall_neff = [
            neff_df.values.min(),
            np.median(neff_df.values),
            neff_df.values.max(),
        ]
        # column 1 of each timing row is the inference time
        mean_inference_time = np.mean(np.array(trial_timing)[:, 1])
        overall_neff_per_time = np.array(overall_neff) / mean_inference_time
        LOGGER.info("== overall n_eff [min, median, max]===")
        LOGGER.info(str(overall_neff))
        LOGGER.info("== overall n_eff/s [min, median, max]===")
        LOGGER.info(str(overall_neff_per_time))
        trial_variable_metrics_data = xr.concat(
            [neff_data, rhat_data], pd.Index(data=["n_eff", "Rhat"], name="metric")
        )
        all_variable_metrics.append(trial_variable_metrics_data)
        all_pll.append(trial_pll)
        all_timing.append(trial_timing)
        all_samples.append(trial_samples_data)
        all_overall_neff.append(overall_neff)
        all_overall_neff_per_time.append(overall_neff_per_time)
    # merge the trial-level metrics at the PPL level
    all_variable_metrics_data = xr.concat(
        all_variable_metrics, pd.Index(data=all_names, name="ppl")
    )
    all_other_metrics_data = xr.Dataset(
        {
            "timing": (["ppl", "chain", "phase"], all_timing),
            "pll": (["ppl", "chain", "draw"], all_pll),
            "overall_neff": (["ppl", "percentile"], all_overall_neff),
            "overall_neff_per_time": (["ppl", "percentile"], all_overall_neff_per_time),
        },
        coords={
            "ppl": np.array(all_names),
            "chain": np.arange(config.trials),
            "phase": np.array(["compile", "infer"]),
            "draw": np.arange(config.num_samples),
            "percentile": np.array(["min", "median", "max"]),
        },
    )
    all_samples_data = xr.concat(all_samples, pd.Index(data=all_names, name="ppl"))
    model_cls.additional_metrics(output_dir, all_samples_data, train_data, test_data)
    LOGGER.info("all benchmark samples and metrics collected")
    # save the samples data only if requested
    if getattr(config, "save_samples", False):
        save_dataset(output_dir, "samples", all_samples_data)
    # write out the metrics
    save_dataset(output_dir, "diagnostics", all_variable_metrics_data)
    save_dataset(output_dir, "metrics", all_other_metrics_data)
    return all_variable_metrics_data, all_other_metrics_data