Example #1
    def __call__(self, samples: tp.Dict[str, np.ndarray],
                 verbose: bool) -> bool:
        idata = from_pyjags(samples)
        ess = az.ess(idata, var_names=self.variable_names)

        minimum_ess = min(value['data']
                          for key, value in ess.to_dict()['data_vars'].items())

        rhat = az.rhat(idata, var_names=self.variable_names)

        maximum_rhat_deviation = max(
            abs(value['data'] - 1.0)
            for key, value in rhat.to_dict()['data_vars'].items())
        if verbose:
            print(f'minimum ess = {minimum_ess}')
            print(f'maximum rhat deviation = {maximum_rhat_deviation}')

        return (minimum_ess >= self.minimum_ess
                and maximum_rhat_deviation <= self.maximum_rhat_deviation)
Example #2
 def track_glm_hierarchical_ess(self, init):
     with glm_hierarchical_model():
         start, step = pm.init_nuts(
             init=init, chains=self.chains, progressbar=False, random_seed=123
         )
         t0 = time.time()
         trace = pm.sample(
             draws=self.draws,
             step=step,
             cores=4,
             chains=self.chains,
             start=start,
             random_seed=100,
             progressbar=False,
             compute_convergence_checks=False,
         )
         tot = time.time() - t0
     ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
     return ess / tot
Example #3
File: benchmarks.py  Project: gdupret/pymc3
 def track_marginal_mixture_model_ess(self, init):
     model, start = mixture_model()
     with model:
         _, step = pm.init_nuts(
             init=init, chains=self.chains, progressbar=False, random_seed=123
         )
         start = [{k: v for k, v in start.items()} for _ in range(self.chains)]
         t0 = time.time()
         trace = pm.sample(
             draws=self.draws,
             step=step,
             cores=4,
             chains=self.chains,
             start=start,
             random_seed=100,
             progressbar=False,
             compute_convergence_checks=False,
         )
         tot = time.time() - t0
     ess = az.ess(trace, var_names=["mu"])["mu"].values.min()  # worst case
     return ess / tot
Example #4
def run(steppers, p):
    steppers = set(steppers)
    traces = {}
    effn = {}
    runtimes = {}

    with pm.Model() as model:
        if USE_XY:
            x = pm.Flat("x")
            y = pm.Flat("y")
            mu = np.array([0.0, 0.0])
            cov = np.array([[1.0, p], [p, 1.0]])
            z = pm.MvNormal.dist(mu=mu, cov=cov,
                                 shape=(2, )).logp(tt.stack([x, y]))
            pot = pm.Potential("logp_xy", z)
            start = {"x": 0, "y": 0}
        else:
            mu = np.array([0.0, 0.0])
            cov = np.array([[1.0, p], [p, 1.0]])
            z = pm.MvNormal("z", mu=mu, cov=cov, shape=(2, ))
            start = {"z": [0, 0]}

        for step_cls in steppers:
            name = step_cls.__name__
            t_start = time.time()
            mt = pm.sample(draws=10000,
                           chains=16,
                           step=step_cls(),
                           start=start)
            runtimes[name] = time.time() - t_start
            print("{} samples across {} chains".format(
                len(mt) * mt.nchains, mt.nchains))
            traces[name] = mt
            en = az.ess(mt)
            print(f"effective: {en}\r\n")
            if USE_XY:
                effn[name] = np.mean(en["x"]) / len(mt) / mt.nchains
            else:
                effn[name] = np.mean(en["z"]) / len(mt) / mt.nchains
    return traces, effn, runtimes
Example #5
def loss_average_plot():
    """Plot effective sample size as a function of n_particles."""
    # Load experiments
    out = pickle.load(open("experiments/loss_averages/out.p", "rb"))
    samples = out['samples']
    particle_schedule = out['particle_schedule']

    # ESS
    ess = [[] for i in range(3)]
    for i in range(len(samples)):
        for run in samples[i]:
            ess[i].append(az.ess(run))

    # Paper plot
    sns.set_context('talk')
    fig, (ax) = plt.subplots(1, 1)

    labels = ['ST-ABC', 'K2-ABC', 'W-ABC']
    line_cols = ['seagreen', 'royalblue', 'darkorange']

    for i in range(len(samples)):
        ax = sns.lineplot(x=particle_schedule,
                          y=ess[i],
                          ax=ax,
                          linewidth=4,
                          linestyle='--',
                          label=labels[i],
                          color=line_cols[i])

    ax.autoscale(enable=True, axis='x', tight=True)
    ax.legend(frameon=False, handlelength=1.8)
    ax.set(xlabel='Number of particles')
    ax.set_ylabel('Effective sample size')
    ax.set(yscale="log")
    fig.set_size_inches(7, 7)
    fig.savefig('loss_average.png', bbox_inches='tight')
Example #6
 def test_neff(self):
     if hasattr(self, "min_n_eff"):
         n_eff = az.ess(self.trace[self.burn:])
         for var in n_eff:
             npt.assert_array_less(self.min_n_eff, n_eff[var])
Example #7
    def normal_proposal(old_point):
        """Symmetric Gaussian proposal centered on the current point."""
        # `Normal` (torch.distributions) and `sigma` come from the surrounding script
        return Normal(old_point, sigma * torch.ones_like(old_point)).sample()

    tf = HarmonicTrialFunction(torch.ones(1))
    n_walkers = 2
    init_config = torch.ones(n_walkers, 1)
    results = metropolis_symmetric(tf, init_config, normal_proposal,
                                   num_walkers=n_walkers, num_steps=100000)
    dataset1 = az.convert_to_dataset(results.numpy())
    dataset2 = az.convert_to_inference_data(results.numpy())

    az.plot_ess(dataset2, kind="local")
    plt.savefig("Local")
    az.plot_ess(dataset2, kind="quantile")
    plt.savefig("quantile")
    az.plot_ess(dataset2, kind="evolution")
    plt.savefig("Evolution")
    print(az.ess(dataset1).data_vars)
# In the Output_of_run array we are using units of 1000.    
#Output_of_run = numpy.array([0.02366, 1.087, 3.579, 7.21, 11.32, 15.9, 20.19, 25.2, 29.98, 32.94, 36.67, 39.41, 38.68, 42.96, 44.4, 45.35, 44.83, 45.94, 43.73, 46.34, 44.69, 45.15, 41.88,41.41, 41.33, 41, 38.46, 38.3, 37.49, 36.02]) 
#y_data = Output_of_run
#x_data = numpy.linspace(0.01, 3, 30)
#plt.scatter(x_data, y_data, c='r', label='ess scatter')
##plt.plot(x_data, y_data, label='ess fit')
#plt.xlabel('x')
#plt.ylabel('y')
#plt.title('ess vs sigma ')
#plt.legend()
#plt.show()
#plt.savefig('ess vs sigma')
Example #8
    def sample(
        self,
        p0,
        Niter,
        ladder=None,
        Tmin=1,
        Tmax=None,
        Tskip=100,
        isave=1000,
        covUpdate=1000,
        weights={},
        initialize_jump_proposal_kwargs={},
        burn=10000,
        maxIter=None,
        thin=1,
        i0=0,
        neff=100000,
        write_cold_chains=False,
        writeHotChains=False,
        save_jump_stats=False,
        hotChain=False,
        n_cold_chains=2,
    ):
        """
        Function to carry out PTMCMC sampling.

        @param p0: Initial parameter vector
        @param Niter: Number of iterations to use for T = 1 chain
        @param ladder: User defined temperature ladder
        @param Tmin: Minimum temperature in ladder (default=1)
        @param Tmax: Maximum temperature in ladder (default=None)
        @param Tskip: Number of steps between proposed temperature swaps (default=100)
        @param isave: Number of iterations before writing to file (default=1000)
        @param covUpdate: Number of iterations between AM covariance updates (default=1000)
        @param weights: Dict of jump-proposal weights in the overall jump cycle
                        (e.g. SCAM, AM, DE, NUTS, MALA, HMC)
        @param burn: Burn-in time (DE jumps added after this iteration) (default=10000)
        @param maxIter: Maximum number of iterations for high temperature chains
                        (default=2*Niter)
        @param thin: Save every thin-th MCMC sample (default=1)
        @param i0: Iteration to start MCMC (if i0 != 0, do not re-initialize)
        @param neff: Number of effective samples to collect before terminating

        """

        # get maximum number of iterations
        if maxIter is None and self.MPIrank > 0:
            maxIter = 2 * Niter
        elif maxIter is None and self.MPIrank == 0:
            maxIter = Niter

        # set up arrays to store lnprob, lnlike and chain
        N = int(maxIter / thin)

        # if picking up from previous run, don't re-initialize
        if i0 == 0:
            self.initialize(
                Niter,
                ladder=ladder,
                Tmin=Tmin,
                Tmax=Tmax,
                Tskip=Tskip,
                isave=isave,
                covUpdate=covUpdate,
                weights=weights,
                initialize_jump_proposal_kwargs=initialize_jump_proposal_kwargs,
                burn=burn,
                maxIter=maxIter,
                thin=thin,
                i0=i0,
                neff=neff,
                writeHotChains=writeHotChains,
                write_cold_chains=write_cold_chains,
                hotChain=hotChain,
                n_cold_chains=n_cold_chains,
            )

        self.jump_proposal_kwargs = {}
        self.weights = weights

        ### compute lnprob for initial point in chain ###

        # if resuming, just start with first point in chain
        if self.resume and self.resumeLength > 0:
            p0, lnlike0, lnprob0 = (
                self.resumechain[0, :-4],
                self.resumechain[0, -3],
                self.resumechain[0, -4],
            )
        else:
            # compute prior
            lp = self.logp(p0)

            if lp == -np.inf:
                lnprob0 = -np.inf
                lnlike0 = -np.inf
            else:
                lnlike0 = self.logl(p0)
                lnprob0 = 1 / self.temp * lnlike0 + lp

        # record first values
        self.updateChains(p0, lnlike0, lnprob0, i0, 0)

        self.comm.barrier()

        # start iterations

        self.tstart = time.time()
        runComplete = False
        Neff = 0
        for i in range(n_cold_chains):
            print("chain %s" % i)
            iter = i0
            for j in range(Niter - 1):
                iter += 1
                accepted = 0

                # call PTMCMCOneStep
                p0, lnlike0, lnprob0 = self.PTMCMCOneStep(
                    p0, lnlike0, lnprob0, iter, i)

                # compute effective number of samples
                if iter % 10000 == 0 and iter > 2 * self.burn and self.MPIrank == 0:
                    try:
                        ### this will calculate the number of effective
                        ### samples for each chain
                        samples = np.expand_dims(self._chain[i, :iter - 1],
                                                 axis=0)
                        arviz_samples = az.convert_to_inference_data(samples)
                        Neff = int(
                            np.min(az.ess(arviz_samples).to_array().values))
                        print("\n {0} total samples".format(iter))
                        print("\n {0} effective samples".format(Neff))

                    except NameError:
                        Neff = 0
                        pass

                # stop if reached effective number of samples
                if self.MPIrank == 0 and int(Neff) > self.neff:
                    if self.verbose:
                        print(
                            "\nRun Complete with {0} effective samples".format(
                                int(Neff)))
                    break

                if self.MPIrank == 0 and runComplete:
                    for jj in range(1, self.nchain):
                        self.comm.send(runComplete, dest=jj, tag=55)

                # check for other chains
                if self.MPIrank > 0:
                    runComplete = self.comm.Iprobe(source=0, tag=55)
                    time.sleep(0.000001)  # brief sleep to avoid busy-waiting on the probe
        return Result(self._chain, self._lnlike, self._lnprob, self.burn,
                      self.n_cold_chains)
Example #9
    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
        if not hasattr(idata, 'posterior'):
            msg = "No posterior samples. Unable to run convergence checks"
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info', None,
                                  None, None)
            self._add_warnings([warn])
            return

        if idata.posterior.sizes['chain'] == 1:
            msg = ("Only one chain was sampled, this makes it impossible to "
                   "run some convergence checks")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info')
            self._add_warnings([warn])
            return

        valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
        varnames = []
        for rv in model.free_RVs:
            rv_name = rv.name
            if is_transformed_name(rv_name):
                rv_name2 = get_untransformed_name(rv_name)
                rv_name = rv_name2 if rv_name2 in valid_name else rv_name
            if rv_name in idata.posterior:
                varnames.append(rv_name)

        self._ess = ess = arviz.ess(idata, var_names=varnames)
        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

        warnings = []
        rhat_max = max(val.max() for val in rhat.values())
        if rhat_max > 1.4:
            msg = ("The rhat statistic is larger than 1.4 for some "
                   "parameters. The sampler did not converge.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'error',
                                  extra=rhat)
            warnings.append(warn)
        elif rhat_max > 1.2:
            msg = ("The rhat statistic is larger than 1.2 for some "
                   "parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'warn',
                                  extra=rhat)
            warnings.append(warn)
        elif rhat_max > 1.05:
            msg = ("The rhat statistic is larger than 1.05 for some "
                   "parameters. This indicates slight problems during "
                   "sampling.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'info',
                                  extra=rhat)
            warnings.append(warn)

        eff_min = min(val.min() for val in ess.values())
        sizes = idata.posterior.sizes
        n_samples = sizes['chain'] * sizes['draw']
        if eff_min < 200 and n_samples >= 500:
            msg = ("The estimated number of effective samples is smaller than "
                   "200 for some parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'error',
                                  extra=ess)
            warnings.append(warn)
        elif eff_min / n_samples < 0.1:
            msg = ("The number of effective samples is smaller than "
                   "10% for some parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'warn',
                                  extra=ess)
            warnings.append(warn)
        elif eff_min / n_samples < 0.25:
            msg = ("The number of effective samples is smaller than "
                   "25% for some parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'info',
                                  extra=ess)
            warnings.append(warn)

        self._add_warnings(warnings)
Example #10
File: report.py  Project: t-triobox/pymc3
    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
        if not hasattr(idata, "posterior"):
            msg = "No posterior samples. Unable to run convergence checks"
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None,
                                  None, None)
            self._add_warnings([warn])
            return

        if idata.posterior.sizes["chain"] == 1:
            msg = ("Only one chain was sampled, this makes it impossible to "
                   "run some convergence checks")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        elif idata.posterior.sizes["chain"] < 4:
            msg = (
                "We recommend running at least 4 chains for robust computation of "
                "convergence diagnostics")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
        varnames = []
        for rv in model.free_RVs:
            rv_name = rv.name
            if is_transformed_name(rv_name):
                rv_name2 = get_untransformed_name(rv_name)
                rv_name = rv_name2 if rv_name2 in valid_name else rv_name
            if rv_name in idata.posterior:
                varnames.append(rv_name)

        self._ess = ess = arviz.ess(idata, var_names=varnames)
        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

        warnings = []
        rhat_max = max(val.max() for val in rhat.values())
        if rhat_max > 1.01:
            msg = ("The rhat statistic is larger than 1.01 for some "
                   "parameters. This indicates problems during sampling. "
                   "See https://arxiv.org/abs/1903.08008 for details")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "info",
                                  extra=rhat)
            warnings.append(warn)

        eff_min = min(val.min() for val in ess.values())
        eff_per_chain = eff_min / idata.posterior.sizes["chain"]
        if eff_per_chain < 100:
            msg = (
                "The effective sample size per chain is smaller than 100 for some parameters. "
                " A higher number is needed for reliable rhat and ess computation. "
                "See https://arxiv.org/abs/1903.08008 for details")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "error",
                                  extra=ess)
            warnings.append(warn)

        self._add_warnings(warnings)
Example #11
import torch
import matplotlib.pyplot as plt
from qmc.mcmc import metropolis_symmetric, clip_mvnormal_proposal, normal_proposal
from qmc.wavefunction import HarmonicTrialFunction
import arviz as az

# First we begin by sampling from a 1D scalar field.
# We will use a simple Gaussian with one parameter.
# In fact, we will just use the harmonic oscillator ansatz.
# We also compute the effective sample size using az.ess() from the arviz package.

tf = HarmonicTrialFunction(torch.ones(1))
n_walkers = 2
init_config = torch.ones(n_walkers, 1)
results = metropolis_symmetric(tf,
                               init_config,
                               normal_proposal,
                               num_walkers=n_walkers,
                               num_steps=10000)
dataset = az.convert_to_dataset(results.numpy())
print(az.ess(dataset))
Example #12
    res_posterior_summary = json.load(f)  # snippet begins inside a `with open(...) as f:` block

res_posterior_summary = pd.DataFrame.from_records(res_posterior_summary,
                                                  index="variable")
res_posterior_summary.index.name = None
print(res_posterior_summary)

reference = (pd.read_csv(f"./reference_posterior_{env_name}.csv",
                         index_col=0,
                         float_precision="high").reset_index().astype(float))

# test arviz functions
funcs = {
    "rhat_rank": lambda x: az.rhat(x, method="rank"),
    "rhat_raw": lambda x: az.rhat(x, method="identity"),
    "ess_bulk": lambda x: az.ess(x, method="bulk"),
    "ess_tail": lambda x: az.ess(x, method="tail"),
    "ess_mean": lambda x: az.ess(x, method="mean"),
    "ess_sd": lambda x: az.ess(x, method="sd"),
    "ess_median": lambda x: az.ess(x, method="median"),
    "ess_raw": lambda x: az.ess(x, method="identity"),
    "ess_quantile01": lambda x: az.ess(x, method="quantile", prob=0.01),
    "ess_quantile10": lambda x: az.ess(x, method="quantile", prob=0.1),
    "ess_quantile30": lambda x: az.ess(x, method="quantile", prob=0.3),
    "mcse_mean": lambda x: az.mcse(x, method="mean"),
    "mcse_sd": lambda x: az.mcse(x, method="sd"),
    "mcse_median": lambda x: az.mcse(x, method="quantile", prob=0.5),
    "mcse_quantile01": lambda x: az.mcse(x, method="quantile", prob=0.01),
    "mcse_quantile10": lambda x: az.mcse(x, method="quantile", prob=0.1),
    "mcse_quantile30": lambda x: az.mcse(x, method="quantile", prob=0.3),
}
Example #13
def compute_ess(chains):
    # shape N_walkers, N_samples
    return arviz.ess(chains.cpu().detach().numpy())
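
A minimal usage sketch (assuming `torch` and `arviz` are imported; the two-walker tensor here is illustrative only):

chains = torch.randn(2, 1000)  # shape (N_walkers, N_samples), iid toy draws
print(compute_ess(chains))     # arviz treats dim 0 as chains, dim 1 as draws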
Example #14
def collect_samples_and_stats(
    config: SimpleNamespace,
    model_cls: Type[BaseModel],
    all_ppl_details: List[PPLDetails],
    train_data: xr.Dataset,
    test_data: xr.Dataset,
    output_dir: str,
) -> Tuple[xr.Dataset, xr.Dataset]:
    """
    :param config: The benchmark configuration.
    :param model_cls: The model class.
    :param all_ppl_details: For each PPL, the impl and inference classes etc.
    :param train_data: The training dataset.
    :param test_data: The held-out test dataset.
    :param output_dir: The directory for storing results.
    :returns: Two datasets:
        variable_metrics
            Coordinates: ppl, metric (n_eff, Rhat), others from model
            Data variables: from model
        other_metrics
            Coordinates: ppl, chain, draw, phase (compile, infer)
            Data variables: pll (ppl, chain, draw), timing (ppl, chain, phase)
    """
    all_variable_metrics, all_pll, all_timing, all_names = [], [], [], []
    all_samples, all_overall_neff, all_overall_neff_per_time = [], [], []
    for pplobj in all_ppl_details:
        all_names.append(pplobj.name)
        rand = np.random.RandomState(pplobj.seed)
        LOGGER.info(f"Starting inference on `{pplobj.name}` with seed {pplobj.seed}")
        # first compile the PPL implementation; this involves two steps
        compile_t1 = time.time()
        # compile step 1: instantiate ppl inference object
        infer_obj = pplobj.inference_class(pplobj.impl_class, train_data.attrs)
        # compile step 2: call compile
        infer_obj.compile(seed=rand.randint(1, int(1e7)), **pplobj.compile_args)
        compile_time = time.time() - compile_t1
        LOGGER.info(f"compiling on `{pplobj.name}` took {compile_time:.2f} secs")

        if infer_obj.is_adaptive:
            num_warmup = (
                config.num_warmup if pplobj.num_warmup is None else pplobj.num_warmup
            )
            if num_warmup > config.iterations:
                raise ValueError(
                    f"num_warmup ({num_warmup}) should be less than iterations "
                    f"({config.iterations})"
                )
        else:
            if pplobj.num_warmup:
                raise ValueError(
                    f"{pplobj.name} is not adaptive and does not accept a nonzero "
                    "num_warmup as its parameter."
                )
            else:
                num_warmup = 0

        # then run inference for each trial
        trial_samples, trial_pll, trial_timing = [], [], []
        for trialnum in range(config.trials):
            infer_t1 = time.time()
            samples = infer_obj.infer(
                data=train_data,
                iterations=config.iterations,
                num_warmup=num_warmup,
                seed=rand.randint(1, int(1e7)),
                **pplobj.infer_args,
            )
            infer_time = time.time() - infer_t1
            LOGGER.info(f"inference trial {trialnum} took {infer_time:.2f} secs")

            # Drop all NaN samples returned by PPL, which could happen when a PPL
            # needs to fill in missing warmup samples (e.g. Jags)
            valid_samples = samples.dropna("draw")
            if samples.sizes["draw"] != config.iterations:
                raise RuntimeError(
                    f"Expected {config.iterations} samples, but {samples.sizes['draw']}"
                    f" samples are returned by {pplobj.name}"
                )

            # compute the pll per sample and then convert it to the actual pll over
            # cumulative samples
            persample_pll = model_cls.evaluate_posterior_predictive(
                valid_samples, test_data
            )
            pll = np.logaddexp.accumulate(persample_pll) - np.log(
                np.arange(valid_samples.sizes["draw"]) + 1
            )
            LOGGER.info(f"PLL = {str(pll)}")
            trial_samples.append(samples)

            # After dropping the NaN samples, the length of PLL could be shorter than
            # iterations. So we'll need to pad the PLL list to make its length
            # consistent across different PPLs
            padded_pll = np.full(config.iterations, np.nan)
            padded_pll[valid_samples.draw.data] = pll
            trial_pll.append(padded_pll)
            trial_timing.append([compile_time, infer_time])
            # finally, give the inference object an opportunity
            # to write additional diagnostics
            infer_obj.additional_diagnostics(output_dir, f"{pplobj.name}_{trialnum}")
        del infer_obj
        # concatenate the samples data from each trial together so we can compute metrics
        trial_samples_data = xr.concat(
            trial_samples, pd.Index(data=np.arange(config.trials), name="chain")
        )
        # exclude warmup samples when calculating diagnostics
        trial_samples_no_warmup = trial_samples_data.isel(draw=slice(num_warmup, None))

        neff_data = arviz.ess(trial_samples_no_warmup)
        rhat_data = arviz.rhat(trial_samples_no_warmup)
        LOGGER.info(f"Trials completed for {pplobj.name}")
        LOGGER.info("== n_eff ===")
        LOGGER.info(str(neff_data.data_vars))
        LOGGER.info("==  Rhat ===")
        LOGGER.info(str(rhat_data.data_vars))

        # compute ess/time
        neff_df = neff_data.to_dataframe()
        overall_neff = [
            neff_df.values.min(),
            np.median(neff_df.values),
            neff_df.values.max(),
        ]
        mean_inference_time = np.mean(np.array(trial_timing)[:, 1])
        overall_neff_per_time = np.array(overall_neff) / mean_inference_time

        LOGGER.info("== overall n_eff [min, median, max]===")
        LOGGER.info(str(overall_neff))
        LOGGER.info("== overall n_eff/s [min, median, max]===")
        LOGGER.info(str(overall_neff_per_time))

        trial_variable_metrics_data = xr.concat(
            [neff_data, rhat_data], pd.Index(data=["n_eff", "Rhat"], name="metric")
        )
        all_variable_metrics.append(trial_variable_metrics_data)
        all_pll.append(trial_pll)
        all_timing.append(trial_timing)
        all_samples.append(trial_samples_data)
        all_overall_neff.append(overall_neff)
        all_overall_neff_per_time.append(overall_neff_per_time)
    # merge the trial-level metrics at the PPL level
    all_variable_metrics_data = xr.concat(
        all_variable_metrics, pd.Index(data=all_names, name="ppl")
    )
    all_other_metrics_data = xr.Dataset(
        {
            "timing": (["ppl", "chain", "phase"], all_timing),
            "pll": (["ppl", "chain", "draw"], all_pll),
            "overall_neff": (["ppl", "percentile"], all_overall_neff),
            "overall_neff_per_time": (["ppl", "percentile"], all_overall_neff_per_time),
        },
        coords={
            "ppl": np.array(all_names),
            "chain": np.arange(config.trials),
            "phase": np.array(["compile", "infer"]),
            "draw": np.arange(config.iterations),
            "percentile": np.array(["min", "median", "max"]),
        },
    )
    all_samples_data = xr.concat(all_samples, pd.Index(data=all_names, name="ppl"))
    model_cls.additional_metrics(output_dir, all_samples_data, train_data, test_data)
    LOGGER.info("all benchmark samples and metrics collected")
    # save the samples data only if requested
    if getattr(config, "save_samples", False):
        save_dataset(output_dir, "samples", all_samples_data)
    # write out the metrics
    save_dataset(output_dir, "diagnostics", all_variable_metrics_data)
    save_dataset(output_dir, "metrics", all_other_metrics_data)
    return all_variable_metrics_data, all_other_metrics_data
Example #15

To help us understand these diagnostics we are going to create two _synthetic posteriors_. The first one is a sample from a uniform distribution. We generate it using SciPy and call it `good_chains`. This is an example of a "good" sample because we are drawing independent and identically distributed (iid) values, which is ideally what we want when approximating a posterior. The second one, called `bad_chains`, represents a poor sample from the posterior. `bad_chains` is a poor _sample_ for two reasons:

* Values are not independent. On the contrary, they are highly correlated: given the number at any position in the sequence, we can compute the exact sequence of numbers both before and after it. High correlation is the opposite of independence.
* Values are not identically distributed: as you will see, we are creating an array with two rows, the first with numbers from 0 to 0.5 and the second from 0.5 to 1.

good_chains = stats.uniform.rvs(0, 1, size=(2, 500))
bad_chains = np.linspace(0, 1, 1000).reshape(2, -1)

## Effective Sample Size (ESS)

When using sampling methods like MCMC, it is common to wonder whether a particular sample is large enough to confidently compute what we want, for example a parameter's mean. Answering in terms of the raw number of samples is generally not a good idea, as samples from MCMC methods are autocorrelated, and autocorrelation decreases the actual amount of information contained in a sample. A better idea is to estimate the **effective sample size (ESS)**: the number of samples we would have if our sample were actually iid.

Using ArviZ we can compute it with `az.ess(⋅)`:

az.ess(good_chains), az.ess(bad_chains)

This is telling us that even though in both cases we have 1000 samples, `bad_chains` is roughly equivalent to an iid sample of size $\approx 2$, while `good_chains` is $\approx 1000$. If you resample `good_chains` you will see that the effective sample size is different for each sample. This is expected, as the samples will not be exactly the same; they are, after all, samples. Nevertheless, on average, the effective sample size will be lower than the number of samples $N$. Notice, however, that the ESS can in fact be larger! When using the NUTS sampler, values of $ESS > N$ can occur for parameters whose posterior distribution is close to Gaussian and which are almost independent of other parameters.
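
A quick sketch of that resampling claim (assuming `numpy as np`, `scipy.stats`, and `arviz as az` as in the surrounding code): drawing fresh iid chains repeatedly yields a slightly different ESS each time, scattered around the nominal 1000.

for seed in range(3):
    rng = np.random.default_rng(seed)
    fresh = rng.uniform(size=(2, 500))   # same shape as good_chains
    print(float(az.ess(fresh)["x"]))     # arviz names a bare array "x"; values hover around 1000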

> As a general rule of thumb we recommend an `ess` greater than 50 per chain; otherwise the estimate of the `ess` itself and the estimate of $\hat R$ are most likely unreliable.
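
That rule of thumb is easy to turn into a guard; a minimal sketch using the 50-per-chain threshold from the quote above:

n_chains = good_chains.shape[0]
ess_total = float(az.ess(good_chains)["x"])
if ess_total / n_chains < 50:
    print("ESS per chain below 50: ess and r-hat estimates are likely unreliable")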

Because MCMC methods can have difficulty with mixing, it is important to use between-chain information when computing the ESS. This is one reason to routinely run more than one chain when fitting a Bayesian model with MCMC methods, as the sketch below illustrates.
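
To see why between-chain information matters, consider two chains that each look fine in isolation but disagree with each other (a hypothetical sketch; the means 0 and 5 are arbitrary):

rng = np.random.default_rng(0)
disagreeing = np.stack([rng.normal(0.0, 1.0, 500),
                        rng.normal(5.0, 1.0, 500)])
print(float(az.ess(disagreeing[:1])["x"]))  # one chain alone: ESS near 500
print(float(az.ess(disagreeing)["x"]))      # both chains pooled: bulk ESS collapses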

We can also compute the effective sample size using `az.summary(⋅)`

az.summary(good_chains)

As you can see, `az.summary(⋅)` provides 4 values for the ESS: mean, sd, bulk, and tail. Moreover, if you check the `method` argument of the `az.ess(⋅)` function you will see the following options:

- "bulk"
- "tail"
Example #16
def collect_samples_and_stats(
    config: SimpleNamespace,
    model_cls: Type[BaseModel],
    all_ppl_details: List[PPLDetails],
    train_data: xr.Dataset,
    test_data: xr.Dataset,
    output_dir: str,
) -> Tuple[xr.Dataset, xr.Dataset]:
    """
    :param config: The benchmark configuration.
    :param model_cls: The model class.
    :param all_ppl_details: For each PPL, the impl and inference classes etc.
    :param train_data: The training dataset.
    :param test_data: The held-out test dataset.
    :param output_dir: The directory for storing results.
    :returns: Two datasets:
        variable_metrics
            Coordinates: ppl, metric (n_eff, Rhat), others from model
            Data variables: from model
        other_metrics
            Coordinates: ppl, chain, draw, phase (compile, infer)
            Data variables: pll (ppl, chain, draw), timing (ppl, chain, phase)
    """
    all_variable_metrics, all_pll, all_timing, all_names = [], [], [], []
    all_samples, all_overall_neff, all_overall_neff_per_time = [], [], []
    for pplobj in all_ppl_details:
        all_names.append(pplobj.name)
        rand = np.random.RandomState(pplobj.seed)
        LOGGER.info(f"Starting inference on `{pplobj.name}` with seed {pplobj.seed}")
        # first compile the PPL implementation; this involves two steps
        compile_t1 = time.time()
        # compile step 1: instantiate ppl inference object
        infer_obj = pplobj.inference_class(pplobj.impl_class, train_data.attrs)
        # compile step 2: call compile
        infer_obj.compile(seed=rand.randint(1, int(1e7)), **pplobj.compile_args)
        compile_time = time.time() - compile_t1
        LOGGER.info(f"compiling on `{pplobj.name}` took {compile_time:.2f} secs")
        # then run inference for each trial
        trial_samples, trial_pll, trial_timing = [], [], []
        for trialnum in range(config.trials):
            infer_t1 = time.time()
            samples = infer_obj.infer(
                data=train_data,
                num_samples=config.num_samples,
                seed=rand.randint(1, int(1e7)),
                **pplobj.infer_args,
            )
            infer_time = time.time() - infer_t1
            LOGGER.info(f"inference trial {trialnum} took {infer_time:.2f} secs")
            # compute the pll per sample and then convert it to the actual pll over
            # cumulative samples
            persample_pll = model_cls.evaluate_posterior_predictive(samples, test_data)
            pll = np.logaddexp.accumulate(persample_pll) - np.log(
                np.arange(config.num_samples) + 1
            )
            LOGGER.info(f"PLL = {str(pll)}")
            trial_samples.append(samples)
            trial_pll.append(pll)
            trial_timing.append([compile_time, infer_time])
            # finally, give the inference object an opportunity
            # to write additional diagnostics
            infer_obj.additional_diagnostics(output_dir, f"{pplobj.name}_{trialnum}")
        del infer_obj
        # concatenate the samples data from each trial together so we can compute metrics
        trial_samples_data = xr.concat(
            trial_samples, pd.Index(data=np.arange(config.trials), name="chain")
        )
        neff_data = arviz.ess(trial_samples_data)
        rhat_data = arviz.rhat(trial_samples_data)
        LOGGER.info(f"Trials completed for {pplobj.name}")
        LOGGER.info("== n_eff ===")
        LOGGER.info(str(neff_data.data_vars))
        LOGGER.info("==  Rhat ===")
        LOGGER.info(str(rhat_data.data_vars))

        # compute ess/time
        neff_df = neff_data.to_dataframe()
        overall_neff = [
            neff_df.values.min(),
            np.median(neff_df.values),
            neff_df.values.max(),
        ]
        mean_inference_time = np.mean(np.array(trial_timing)[:, 1])
        overall_neff_per_time = np.array(overall_neff) / mean_inference_time

        LOGGER.info("== overall n_eff [min, median, max]===")
        LOGGER.info(str(overall_neff))
        LOGGER.info("== overall n_eff/s [min, median, max]===")
        LOGGER.info(str(overall_neff_per_time))

        trial_variable_metrics_data = xr.concat(
            [neff_data, rhat_data], pd.Index(data=["n_eff", "Rhat"], name="metric")
        )
        all_variable_metrics.append(trial_variable_metrics_data)
        all_pll.append(trial_pll)
        all_timing.append(trial_timing)
        all_samples.append(trial_samples_data)
        all_overall_neff.append(overall_neff)
        all_overall_neff_per_time.append(overall_neff_per_time)
    # merge the trial-level metrics at the PPL level
    all_variable_metrics_data = xr.concat(
        all_variable_metrics, pd.Index(data=all_names, name="ppl")
    )
    all_other_metrics_data = xr.Dataset(
        {
            "timing": (["ppl", "chain", "phase"], all_timing),
            "pll": (["ppl", "chain", "draw"], all_pll),
            "overall_neff": (["ppl", "percentile"], all_overall_neff),
            "overall_neff_per_time": (["ppl", "percentile"], all_overall_neff_per_time),
        },
        coords={
            "ppl": np.array(all_names),
            "chain": np.arange(config.trials),
            "phase": np.array(["compile", "infer"]),
            "draw": np.arange(config.num_samples),
            "percentile": np.array(["min", "median", "max"]),
        },
    )
    all_samples_data = xr.concat(all_samples, pd.Index(data=all_names, name="ppl"))
    model_cls.additional_metrics(output_dir, all_samples_data, train_data, test_data)
    LOGGER.info("all benchmark samples and metrics collected")
    # save the samples data only if requested
    if getattr(config, "save_samples", False):
        save_dataset(output_dir, "samples", all_samples_data)
    # write out the metrics
    save_dataset(output_dir, "diagnostics", all_variable_metrics_data)
    save_dataset(output_dir, "metrics", all_other_metrics_data)
    return all_variable_metrics_data, all_other_metrics_data