Example #1
def rhat(model_data: TaskModel,
         less: Optional[float] = None) -> Dict[str, Union[List, bool]]:
    """Function for extracting Rhat values from hbayesdm output.

    Convenience function for extracting Rhat values from hbayesdm output.
    Also possible to check if all Rhat values are less than a specified value.

    Parameters
    ----------
    model_data
        Output instance of running an hbayesdm model function.
    less
        [Optional] Upper-bound value to compare extracted Rhat values to.

    Returns
    -------
    Dict
        Keys are names of the parameters; values are their Rhat values.
        Or if `less` was specified, the dictionary values will hold `True` if
        all Rhat values (of that parameter) are less than or equal to `less`.
    """
    rhat_data = az.rhat(model_data.fit)
    if less is None:
        return {
            v.name: v.values.tolist()
            for v in rhat_data.data_vars.values()
        }
    else:
        return {
            v.name: v.values.item()
            for v in (rhat_data.max() <= less).data_vars.values()
        }
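A short usage sketch, assuming hbayesdm is installed; the model choice, sampling arguments, and printed parameter names are illustrative only:

from hbayesdm.models import ra_prospect

# fit the bundled example dataset, then inspect convergence
output = ra_prospect(data="example", niter=2000, nwarmup=1000, nchain=4)
print(rhat(output))            # e.g. {'mu_rho': [...], ...}
print(rhat(output, less=1.1))  # e.g. {'mu_rho': True, ...}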
Example #2
def mcmc_stats(runs, burnin, prob, batch):
    """
        Input
        runs:   Monte Carlo samples
        burnin: number of burn-in draws
        prob:   interval probability (0 < prob < 1)
        batch:  number of batches to split each series into
        Output
        data frame of posterior summary statistics
    """
    traces = runs[burnin:, :]
    n = traces.shape[0] // batch
    k = traces.shape[1]
    alpha = 100 * (1.0 - prob)
    post_mean = np.mean(traces, axis=0)
    post_median = np.median(traces, axis=0)
    post_sd = np.std(traces, axis=0)
    # reshape each series to (batch, n): ArviZ's convention is (chain, draw)
    mc_err = [az.mcse(traces[:, i].reshape((batch, n))).item(0)
              for i in range(k)]
    ci_lower = np.percentile(traces, 0.5 * alpha, axis=0)
    ci_upper = np.percentile(traces, 100 - 0.5 * alpha, axis=0)
    hpdi = az.hdi(traces, prob)
    rhat = [az.rhat(traces[:, i].reshape((batch, n))).item(0)
            for i in range(k)]
    stats = np.vstack((post_mean, post_median, post_sd, mc_err, ci_lower,
                       ci_upper, hpdi.T, rhat)).T
    stats_string = [
        'mean', 'median', 'sd', 'MC error', 'CI (lower)', 'CI (upper)',
        'HPDI (lower)', 'HPDI (upper)', '$\\hat R$'
    ]
    param_string = ['mean $\\mu$', 'variance $\\sigma^2$']
    return pd.DataFrame(stats, index=param_string, columns=stats_string)
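A usage sketch with synthetic draws (illustrative; it assumes an ArviZ version whose `az.hdi` treats a 2-D array as (draw, parameter), which matches how this function stacks `hpdi.T`):

import numpy as np

rng = np.random.default_rng(0)
# two series of 6000 draws each: a "mean" and a "variance" parameter
runs = np.column_stack([rng.normal(0.0, 1.0, 6000), rng.gamma(2.0, 1.0, 6000)])
print(mcmc_stats(runs, burnin=1000, prob=0.95, batch=5))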
Example #3
def task_run_model(country: str, region: str, run_date: pd.Timestamp):
    """ Run the Generative model for a given region on a given date, store
        inference data into S3. """
    key = get_processed_covidtracking_key(run_date)
    with fs.open(f"{S3_BUCKET}/{key}") as file:
        df = pd.read_csv(file,
                         index_col=["region", "date"],
                         parse_dates=["date"])

    model_input = df.xs(region)
    gm = GenerativeModel(region, model_input)
    gm.sample()

    inference_data = gm.inference_data

    # Ensure no divergences
    assert (gm.n_divergences == 0
            ), f"Model {region} had {gm.n_divergences} divergences, failing."

    # Ensure convergence
    R_HAT_LIMIT = 1.1
    r_hat = az.rhat(inference_data).to_dataframe().fillna(1.0)
    assert r_hat.le(R_HAT_LIMIT).all().all(), \
        f"Model {region} r_hat exceeded {R_HAT_LIMIT}, failing."

    with tempfile.NamedTemporaryFile() as fp:
        inference_data.to_netcdf(fp.name)
        fp.seek(0)
        s3.Bucket(S3_BUCKET).upload_fileobj(
            fp, get_inference_data_key(run_date, region, country=country))
    return {"country": country, "region": region, "r_hat": r_hat}
Example #4
    def __call__(self, samples: tp.Dict[str, np.ndarray],
                 verbose: bool) -> bool:
        idata = az.from_pyjags(samples)
        rhat = az.rhat(idata, var_names=self.variable_names)

        maximum_rhat_deviation = max(
            abs(value['data'] - 1.0)
            for key, value in rhat.to_dict()['data_vars'].items())

        if verbose:
            print(f'maximum rhat deviation = {maximum_rhat_deviation}')

        return maximum_rhat_deviation <= self.maximum_rhat_deviation
Example #5
def run_convergence_checks(az_data):
    if az_data.posterior.dims['chain'] == 1:
        msg = ("Only one chain was sampled, this makes it impossible to "
               "run some convergence checks")
        return [msg]

    from arviz import rhat, ess

    ess = ess(az_data)
    rhat = rhat(az_data)

    warnings = []
    rhat_max = float(rhat.max()['a'].values)

    if rhat_max > 1.4:
        msg = ("ERROR: The rhat statistic is larger than 1.4 for some "
               "parameters. The sampler did not converge.")
        warnings.append(msg)
    elif rhat_max > 1.2:
        msg = ("WARN: The rhat statistic is larger than 1.2 for some "
               "parameters.")
        warnings.append(msg)
    elif rhat_max > 1.05:
        msg = ("INFO: The rhat statistic is larger than 1.05 for some "
               "parameters. This indicates slight problems during "
               "sampling.")
        warnings.append(msg)

    eff_min = float(ess.min()['a'].values)
    n_samples = az_data.posterior.dims['draw'] * az_data.posterior.dims['chain']
    if eff_min < 200 and n_samples >= 500:
        msg = ("ERROR: The estimated number of effective samples is smaller than "
               "200 for some parameters.")
        warnings.append(msg)
    elif eff_min / n_samples < 0.1:
        msg = ("WARN: The number of effective samples is smaller than "
               "10% for some parameters.")
        warnings.append(msg)
    elif eff_min / n_samples < 0.25:
        msg = ("INFO: The number of effective samples is smaller than "
               "25% for some parameters.")

        warnings.append(msg)

    return warnings
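A usage sketch; note the checks above only inspect a posterior variable named 'a', so the synthetic chains use that name:

import numpy as np
import arviz as az

idata = az.from_dict(posterior={"a": np.random.randn(4, 1000)})
for msg in run_convergence_checks(idata):
    print(msg)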
Example #6
def model_diagnose(model, trace, var_names):
    '''
    Diagnose a model based on effective sample size (ESS) and Rhat.

    :param model: a PyMC3 model
    :param trace: sample trace
    :param var_names: variable names
    :return: None
    '''
    ess = az.ess(trace, relative=True)

    print("Effective Sample Size (min across parameters)")
    for var in var_names:
        print(f"\t{var}: {ess[var].values.min()}")
    rhat = az.rhat(trace)

    print("rhat (max across parameters)")
    for var in var_names:
        print(f"\t{var}: {rhat[var].values.max()}")
Example #7
    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
        if not hasattr(idata, "posterior"):
            msg = "No posterior samples. Unable to run convergence checks"
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None,
                                  None, None)
            self._add_warnings([warn])
            return

        if idata.posterior.sizes["chain"] == 1:
            msg = ("Only one chain was sampled, this makes it impossible to "
                   "run some convergence checks")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        elif idata.posterior.sizes["chain"] < 4:
            msg = (
                "We recommend running at least 4 chains for robust computation of "
                "convergence diagnostics")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
        varnames = []
        for rv in model.free_RVs:
            rv_name = rv.name
            if is_transformed_name(rv_name):
                rv_name2 = get_untransformed_name(rv_name)
                rv_name = rv_name2 if rv_name2 in valid_name else rv_name
            if rv_name in idata.posterior:
                varnames.append(rv_name)

        self._ess = ess = arviz.ess(idata, var_names=varnames)
        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

        warnings = []
        rhat_max = max(val.max() for val in rhat.values())
        if rhat_max > 1.01:
            msg = ("The rhat statistic is larger than 1.01 for some "
                   "parameters. This indicates problems during sampling. "
                   "See https://arxiv.org/abs/1903.08008 for details")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "info",
                                  extra=rhat)
            warnings.append(warn)

        eff_min = min(val.min() for val in ess.values())
        eff_per_chain = eff_min / idata.posterior.sizes["chain"]
        if eff_per_chain < 100:
            msg = (
                "The effective sample size per chain is smaller than 100 for some parameters. "
                "A higher number is needed for reliable rhat and ess computation. "
                "See https://arxiv.org/abs/1903.08008 for details")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "error",
                                  extra=ess)
            warnings.append(warn)

        self._add_warnings(warnings)
Example #8
def rhat(trace, var_name):
    rhat_samples = (az.rhat(data=trace, var_names=var_name)).to_dataframe()
    rhat_samples = pd.DataFrame({
        "rhat": rhat_samples[var_name].values,
        "param": [var_name + str(i) for i in range(len(rhat_samples))]})
    return rhat_samples
Example #9
def get_rhat(samples):
    # expect (sample, chain) shape for a single parameter
    assert len(samples.shape) == 2
    # arviz convention is (chain, sample), not (sample, chain), so swap axes
    return arviz.rhat(np.swapaxes(samples, 0, 1))
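For instance, with four synthetic chains stored column-wise:

import numpy as np
import arviz

samples = np.random.randn(1000, 4)  # (sample, chain)
print(get_rhat(samples))            # a single rhat value near 1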
Example #10
extension_data_RS = np.loadtxt('extension_data_RS.csv', delimiter=',')

mean_stress_data_4 = np.loadtxt('mean_stress_data_4.csv', delimiter=',')
mean_stress_data_20 = np.loadtxt('mean_stress_data_20.csv', delimiter=',')
mean_stress_data_40 = np.loadtxt('mean_stress_data_40.csv', delimiter=',')
mean_stress_data_RS = np.loadtxt('mean_stress_data_RS.csv', delimiter=',')
std_stress_data_4 = np.loadtxt('std_stress_data_4.csv', delimiter=',')
std_stress_data_20 = np.loadtxt('std_stress_data_20.csv', delimiter=',')
std_stress_data_40 = np.loadtxt('std_stress_data_40.csv', delimiter=',')
std_stress_data_RS = np.loadtxt('std_stress_data_RS.csv', delimiter=',')

data = az.from_netcdf('save_arviz_data_stanwound')

az.style.use("default")

az.rhat(data, var_names=['kv', 'k0', 'kf', 'k2', 'b', 'mu', 'phif'])

extra_kwargs = {"color": "lightsteelblue"}

az.plot_ess(data,
            kind="local",
            var_names=['kv', 'k0', 'kf', 'k2', 'b', 'mu', 'phif'],
            figsize=(18, 18),
            color="royalblue",
            extra_kwargs=extra_kwargs,
            textsize=20)

az.plot_ess(data,
            kind="quantile",
            var_names=['kv', 'k0', 'kf', 'k2', 'b', 'mu', 'phif'],
            figsize=(18, 18))
Example #11
def diagnose_energy_chains(chains, burnin, thinning, num_supp, alpha,
                           save_dir):
    """Diagnoses MH chain convergence.

    Args:
        chains: Array, where each row is an independent chain.
        burnin: Int, number of iterations to burn.
        thinning: Int, number of iterations - 1 between accepted samples.
        num_supp: Int, number of support points.
        alpha: Float, privacy budget.
        save_dir: String, directory to save plot.

    Returns:
        None
    """

    chains = np.transpose(chains)  # Now each chain is a col.
    n, num_chains = chains.shape

    # Plot chains to show convergence to similar energy value.
    fig, ax = plt.subplots(1, 4, figsize=(12, 6))

    ax[0].plot(chains)
    ax[0].set_xlabel('Iteration')
    ax[0].set_ylabel('Energy')

    ax[1].plot(chains[-500:, :])
    ax[1].set_xlabel('Last 500 Iterations')
    ax[1].set_ylabel('Energy')

    # Compute Gelman-Rubin diagnostic over accumulating history, post burn-in.
    gelman_rubin_stats = []
    running_betweens = []
    running_withins = []

    range_i = np.arange(burnin + 5, n, 100)
    for i in range_i:
        chains_so_far = chains[burnin:i, :]

        between_chain_var = np.var(np.mean(chains_so_far, axis=0), ddof=1)
        within_chain_var = np.mean(np.var(chains_so_far, axis=0, ddof=1))
        running_betweens.append(between_chain_var)
        running_withins.append(within_chain_var)

        # https://arviz-devs.github.io/arviz/generated/arviz.rhat.html
        chains_so_far = np.transpose(chains[burnin:i, :])
        gelman_rubin_stat = az.rhat(chains_so_far)
        gelman_rubin_stats.append(gelman_rubin_stat)

    ax[2].plot(range_i, gelman_rubin_stats)
    ax[2].set_xlabel('Iteration')
    ax[2].set_ylabel('Gelman-Rubin')

    ax[3].plot(range_i, running_betweens, label='between')
    ax[3].plot(range_i, running_withins, label='within')
    ax[3].set_xlabel('Iteration')
    ax[3].set_ylabel('Variance')
    ax[3].legend()

    plt.tight_layout()
    plt.savefig(
        os.path.join(save_dir,
                     'mh_convergence_supp{}_eps{}.png'.format(num_supp,
                                                              alpha)))

    plt.close()
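A usage sketch with synthetic energy traces (values and arguments are illustrative; the plot is written to the current directory):

import numpy as np

rng = np.random.default_rng(1)
chains = rng.normal(size=(4, 3000))  # one chain per row, as documented
diagnose_energy_chains(chains, burnin=500, thinning=1,
                       num_supp=10, alpha=1.0, save_dir=".")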
Example #12
#%%

with az.rc_context(rc={'plot.max_subplots': None}):
    az.plot_trace(result_palms_5000, var_names="beta")
    plt.show()

#%%

with az.rc_context(rc={'plot.max_subplots': None}):
    az.plot_trace(result_palms, var_names=["beta", "b_raw"])
    plt.show()

#%%

az.rhat(result_palms, method="folded")

#%%
# run palms comparison multiple times

results = []
n_chains = 50

model_palms = mod.CompositionalAnalysis(data[data.obs["site"].isin(
    ["left palm", "right palm"])],
                                        "site",
                                        baseline_index=None)

for n in range(n_chains):
    result_temp = model_palms.sample_hmc(num_results=int(20000), n_burnin=5000)
Example #13
                         beta=(1 - mu) * kappa,
                         observed=y)

    # sample
    trace = pm.sample(4000, tune=1000, chains=2)

# summarize results
summary_coef = np.quantile(trace.beta,
                           axis=0,
                           q=[0.5, 0.025, 0.25, 0.75, 0.975])
summary_coef = pd.DataFrame(np.transpose(summary_coef))
summary_coef.index = X.columns
summary_coef.columns = ['median', 'lower95', 'lower50', 'upper50', 'upper95']
summary_coef['P(x > 0)'] = [(trace.beta[:, i] > 0).sum() / trace.beta.shape[0]
                            for i in range(trace.beta.shape[1])]
summary_coef['rhat'] = az.rhat(trace).beta
summary_coef = summary_coef.drop(index=x_control.columns)

# plot
summary_coef['var_name'] = cov_name
summary_coef = summary_coef[::-1]
summary_coef['var_name'] = pd.Categorical(summary_coef['var_name'],
                                          categories=summary_coef['var_name'])

min_val = summary_coef.lower95.min()
max_val = summary_coef.upper95.max()
min_range = min_val - (max_val - min_val) * 0.1
max_range = max_val + (max_val - min_val) * 0.1

# point color
foo = zip(summary_coef.lower95 * summary_coef.upper95,
Example #14
$$\hat{\tau} = -1 + 2 \sum_{t'=0}^K \hat{P}_{t'}$$

where $\hat{\rho}_t$ is the estimated autocorrelation at lag $t$, and $K$ is the largest integer for which $\hat{P}_{K} = \hat{\rho}_{2K} + \hat{\rho}_{2K+1}$ is still positive. The reason to compute this truncated sum, summing over $K$ terms instead of over all available terms, is that for large values of $t$ the sample autocorrelation becomes too noisy to be useful, so we simply discard those terms in order to get a more robust estimate.
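As a rough sketch of this estimator (not ArviZ's actual implementation, which uses FFT-based autocorrelation estimates and rank normalization), assuming a 1-D array of draws from a single chain:

import numpy as np

def tau_hat(draws):
    # truncated-sum estimate of the integrated autocorrelation time
    x = draws - draws.mean()
    n = len(x)
    acov = np.correlate(x, x, mode="full")[n - 1:] / n  # autocovariances, lags 0..n-1
    rho = acov / acov[0]                                # autocorrelations rho_t
    total = 0.0
    for t in range(0, n - 1, 2):
        pair = rho[t] + rho[t + 1]                      # P_t' = rho_{2t'} + rho_{2t'+1}
        if pair <= 0:                                   # stop at the largest K with P_K > 0
            break
        total += pair
    return -1.0 + 2.0 * total

# independent draws give tau-hat close to 1:
print(tau_hat(np.random.default_rng(0).normal(size=4000)))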

## $\hat R$ (aka R hat, or the Gelman-Rubin statistic)


Under very general conditions MCMC methods have theoretical guarantees that you will get the right answer irrespective of the starting point. Unfortunately, we only have guarantees for infinite samples. One way to get a useful estimate of convergence for finite samples is to run more than one chain, starting from very different points, and then check whether the resulting chains _look similar_ to each other. $\hat R$ is a formalization of this idea, and it works by comparing the _within-chain_ variance to the _between-chain_ variance. Ideally we should get a value of 1.

Conceptually $\hat R$ can be interpreted as the overestimation of variance due to finite MCMC sampling. If you continued sampling infinitely, you would get a reduction of the variance of your estimate by a factor of $\hat R$.

From a practical point of view, values of $\hat R \lessapprox 1.01$ are considered safe.

Using ArviZ we can compute it with `az.summary(⋅)`, as we already saw in the previous section, or with `az.rhat(⋅)`:

az.rhat(good_chains), az.rhat(bad_chains)

## $\hat R$ in depth


The value of $\hat R$ is computed using the between-chain variance $B$ and the within-chain variance $W$, and then assessing whether they are different enough to worry about convergence. For $M$ chains, each of length $N$, we compute for each scalar parameter $\theta$:

\begin{split}B &= \frac{N}{M-1} \sum_{m=1}^M (\bar{\theta}_{.m} - \bar{\theta}_{..})^2 \\
W &= \frac{1}{M} \sum_{m=1}^M \left[ \frac{1}{N-1} \sum_{n=1}^N (\theta_{nm} - \bar{\theta}_{.m})^2 \right]\end{split}

where:

$\bar{\theta}_{.m} = \frac{1}{N} \sum_{n=1}^N \theta_{nm}$

$\bar{\theta}_{..} = \frac{1}{M} \sum_{m=1}^M \bar{\theta}_{.m}$
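The classic (non rank-normalized) estimate combines these as $\hat{V} = \frac{N-1}{N} W + \frac{1}{N} B$ and $\hat R = \sqrt{\hat{V} / W}$. A minimal NumPy sketch of the formulas above (ArviZ's default is the more robust rank-normalized split-$\hat R$):

import numpy as np

def classic_rhat(chains):
    # chains: shape (M, N) -- M chains with N draws each
    M, N = chains.shape
    chain_means = chains.mean(axis=1)      # theta-bar_.m
    grand_mean = chain_means.mean()        # theta-bar_..
    B = N / (M - 1) * np.sum((chain_means - grand_mean) ** 2)
    W = np.mean(chains.var(axis=1, ddof=1))
    V_hat = (N - 1) / N * W + B / N
    return np.sqrt(V_hat / W)

# well-mixed chains give values close to 1:
rng = np.random.default_rng(42)
print(classic_rhat(rng.normal(size=(4, 1000))))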
Example #15
def collect_samples_and_stats(
    config: SimpleNamespace,
    model_cls: Type[BaseModel],
    all_ppl_details: List[PPLDetails],
    train_data: xr.Dataset,
    test_data: xr.Dataset,
    output_dir: str,
) -> Tuple[xr.Dataset, xr.Dataset]:
    """
    :param config: The benchmark configuration.
    :param model_cls: The model class.
    :param all_ppl_details: For each PPL, the impl and inference classes etc.
    :param train_data: The training dataset.
    :param test_data: The held-out test dataset.
    :param output_dir: The directory for storing results.
    :returns: Two datasets:
        variable_metrics
            Coordinates: ppl, metric (n_eff, Rhat), others from model
            Data variables: from model
        other_metrics
            Coordinates: ppl, chain, draw, phase (compile, infer)
            Data variables: pll (ppl, chain, draw), timing (ppl, chain, phase)
    """
    all_variable_metrics, all_pll, all_timing, all_names = [], [], [], []
    all_samples, all_overall_neff, all_overall_neff_per_time = [], [], []
    for pplobj in all_ppl_details:
        all_names.append(pplobj.name)
        rand = np.random.RandomState(pplobj.seed)
        LOGGER.info(f"Starting inference on `{pplobj.name}` with seed {pplobj.seed}")
        # first compile the PPL implementation; this involves two steps
        compile_t1 = time.time()
        # compile step 1: instantiate ppl inference object
        infer_obj = pplobj.inference_class(pplobj.impl_class, train_data.attrs)
        # compile step 2: call compile
        infer_obj.compile(seed=rand.randint(1, int(1e7)), **pplobj.compile_args)
        compile_time = time.time() - compile_t1
        LOGGER.info(f"compiling on `{pplobj.name}` took {compile_time:.2f} secs")

        if infer_obj.is_adaptive:
            num_warmup = (
                config.num_warmup if pplobj.num_warmup is None else pplobj.num_warmup
            )
            if num_warmup > config.iterations:
                raise ValueError(
                    f"num_warmup ({num_warmup}) should be less than iterations "
                    f"({config.iterations})"
                )
        else:
            if pplobj.num_warmup:
                raise ValueError(
                    f"{pplobj.name} is not adaptive and does not accept a nonzero "
                    "num_warmup as its parameter."
                )
            else:
                num_warmup = 0

        # then run inference for each trial
        trial_samples, trial_pll, trial_timing = [], [], []
        for trialnum in range(config.trials):
            infer_t1 = time.time()
            samples = infer_obj.infer(
                data=train_data,
                iterations=config.iterations,
                num_warmup=num_warmup,
                seed=rand.randint(1, int(1e7)),
                **pplobj.infer_args,
            )
            infer_time = time.time() - infer_t1
            LOGGER.info(f"inference trial {trialnum} took {infer_time:.2f} secs")

            # Drop all NaN samples returned by PPL, which could happen when a PPL
            # needs to fill in missing warmup samples (e.g. Jags)
            valid_samples = samples.dropna("draw")
            if samples.sizes["draw"] != config.iterations:
                raise RuntimeError(
                    f"Expected {config.iterations} samples, but {samples.sizes['draw']}"
                    f" samples are returned by {pplobj.name}"
                )

            # compute the pll per sample and then convert it to the actual pll over
            # cumulative samples
            persample_pll = model_cls.evaluate_posterior_predictive(
                valid_samples, test_data
            )
            pll = np.logaddexp.accumulate(persample_pll) - np.log(
                np.arange(valid_samples.sizes["draw"]) + 1
            )
            LOGGER.info(f"PLL = {str(pll)}")
            trial_samples.append(samples)

            # After dropping the NaN samples, the length of PLL could be shorter than
            # iterations. So we'll need to pad the PLL list to make its length
            # consistent across different PPLs
            padded_pll = np.full(config.iterations, np.nan)
            padded_pll[valid_samples.draw.data] = pll
            trial_pll.append(padded_pll)
            trial_timing.append([compile_time, infer_time])
            # finally, give the inference object an opportunity
            # to write additional diagnostics
            infer_obj.additional_diagnostics(output_dir, f"{pplobj.name}_{trialnum}")
        del infer_obj
        # concatenate the samples data from each trial together so we can compute metrics
        trial_samples_data = xr.concat(
            trial_samples, pd.Index(data=np.arange(config.trials), name="chain")
        )
        # exclude warmup samples when calculating diagnostics
        trial_samples_no_warmup = trial_samples_data.isel(draw=slice(num_warmup, None))

        neff_data = arviz.ess(trial_samples_no_warmup)
        rhat_data = arviz.rhat(trial_samples_no_warmup)
        LOGGER.info(f"Trials completed for {pplobj.name}")
        LOGGER.info("== n_eff ===")
        LOGGER.info(str(neff_data.data_vars))
        LOGGER.info("==  Rhat ===")
        LOGGER.info(str(rhat_data.data_vars))

        # compute ess/time
        neff_df = neff_data.to_dataframe()
        overall_neff = [
            neff_df.values.min(),
            np.median(neff_df.values),
            neff_df.values.max(),
        ]
        mean_inference_time = np.mean(np.array(trial_timing)[:, 1])
        overall_neff_per_time = np.array(overall_neff) / mean_inference_time

        LOGGER.info("== overall n_eff [min, median, max]===")
        LOGGER.info(str(overall_neff))
        LOGGER.info("== overall n_eff/s [min, median, max]===")
        LOGGER.info(str(overall_neff_per_time))

        trial_variable_metrics_data = xr.concat(
            [neff_data, rhat_data], pd.Index(data=["n_eff", "Rhat"], name="metric")
        )
        all_variable_metrics.append(trial_variable_metrics_data)
        all_pll.append(trial_pll)
        all_timing.append(trial_timing)
        all_samples.append(trial_samples_data)
        all_overall_neff.append(overall_neff)
        all_overall_neff_per_time.append(overall_neff_per_time)
    # merge the trial-level metrics at the PPL level
    all_variable_metrics_data = xr.concat(
        all_variable_metrics, pd.Index(data=all_names, name="ppl")
    )
    all_other_metrics_data = xr.Dataset(
        {
            "timing": (["ppl", "chain", "phase"], all_timing),
            "pll": (["ppl", "chain", "draw"], all_pll),
            "overall_neff": (["ppl", "percentile"], all_overall_neff),
            "overall_neff_per_time": (["ppl", "percentile"], all_overall_neff_per_time),
        },
        coords={
            "ppl": np.array(all_names),
            "chain": np.arange(config.trials),
            "phase": np.array(["compile", "infer"]),
            "draw": np.arange(config.iterations),
            "percentile": np.array(["min", "median", "max"]),
        },
    )
    all_samples_data = xr.concat(all_samples, pd.Index(data=all_names, name="ppl"))
    model_cls.additional_metrics(output_dir, all_samples_data, train_data, test_data)
    LOGGER.info("all benchmark samples and metrics collected")
    # save the samples data only if requested
    if getattr(config, "save_samples", False):
        save_dataset(output_dir, "samples", all_samples_data)
    # write out the metrics
    save_dataset(output_dir, "diagnostics", all_variable_metrics_data)
    save_dataset(output_dir, "metrics", all_other_metrics_data)
    return all_variable_metrics_data, all_other_metrics_data
"""
trace = pm.sample(
    model=this_model,
    init="advi",
    tune=5000,
    draws=5000,
    chains=4,
    cores=4,
    progressbar=True,
)

# Save trace in case there are some problems with post processing
with open(f"./pickled/{args.iso2}.pickle", "wb") as f:
    pickle.dump((this_model, trace), f)

if az.rhat(trace).max().to_array().max() > 1.1:
    log.error("Rhat greater than 1.1")
    exit()
""" # Data post processing (submission)
We compute the sum of all new cases for the next weeks as defined here:
- https://github.com/epiforecasts/covid19-forecast-hub-europe/wiki/Forecast-format
- Epidemiological Weeks: Each week starts on Sunday and ends on Saturday

Columns in csv
--------------
forecast_date: date 
    Date as YYYY-MM-DD, last day (Monday) of submission window
scenario_id: string, optional
    One of "forecast" or a specified "scenario ID". If this column is not included it will be assumed that its value is "forecast" for all rows
target:  string  
    "# wk ahead inc case" or "# wk ahead inc death" where # is usually between 1 and 4
Example #17
with baseball_model:

    theta_new = pm.Beta("theta_new",
                        alpha=phi * kappa,
                        beta=(1.0 - phi) * kappa)
    y_new = pm.Binomial("y_new", n=50, p=theta_new, observed=0)

with baseball_model:
    trace = pm.sample(2000,
                      tune=2000,
                      chains=2,
                      target_accept=0.95,
                      return_inferencedata=True)

    # check convergence diagnostics
    assert (az.rhat(trace).to_array() < 1.03).all()

az.plot_trace(trace, var_names=["phi", "kappa"])

# In[21]:

player_names = data['name']

ax = az.plot_forest(trace, var_names=["thetas"])
ax[0].set_yticklabels(player_names.tolist())

# In[19]:

az.plot_trace(trace, var_names=["theta_new"])
# 4 at-bats
        eta = pm.HalfCauchy("eta", beta=1)

        cov = eta ** 2 * pm.gp.cov.Matern52(1, l_)
        gp = pm.gp.Latent(cov_func=cov)

        f = gp.prior("f", X=X)

        sigma = pm.HalfCauchy("sigma", beta=5)
        nu = pm.Gamma("nu", alpha=2, beta=0.1)
        y_ = pm.StudentT("y", mu=f, lam=1.0 / sigma, nu=nu, observed=y)

        trace = pm.sample(200, n_init=100, tune=100, chains=2, cores=2, return_inferencedata=True)
        az.to_netcdf(trace, 'src/experiments/results/lat_gp_trace')

    # check Rhat, values above 1 may indicate convergence issues
    n_nonconverged = int(np.sum(az.rhat(trace)[["eta", "l", "f_rotated_"]].to_array() > 1.03).values)
    print("%i variables MCMC chains appear not to have converged." % n_nonconverged)

    # plot the results
    fig = plt.figure(figsize=(12, 5))
    ax = fig.gca()

    # plot the samples from the gp posterior with samples and shading
    from pymc3.gp.util import plot_gp_dist

    plot_gp_dist(ax, trace.posterior["f"][0, :, :], X)

    # plot the data and the true latent function
    ax.plot(X, f_true, "dodgerblue", lw=3, label="True generating function 'f'")
    ax.plot(X, y, "ok", ms=3, label="Observed data")
Example #19
    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
        if not hasattr(idata, 'posterior'):
            msg = "No posterior samples. Unable to run convergence checks"
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info', None,
                                  None, None)
            self._add_warnings([warn])
            return

        if idata.posterior.sizes['chain'] == 1:
            msg = ("Only one chain was sampled, this makes it impossible to "
                   "run some convergence checks")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info')
            self._add_warnings([warn])
            return

        valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
        varnames = []
        for rv in model.free_RVs:
            rv_name = rv.name
            if is_transformed_name(rv_name):
                rv_name2 = get_untransformed_name(rv_name)
                rv_name = rv_name2 if rv_name2 in valid_name else rv_name
            if rv_name in idata.posterior:
                varnames.append(rv_name)

        self._ess = ess = arviz.ess(idata, var_names=varnames)
        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

        warnings = []
        rhat_max = max(val.max() for val in rhat.values())
        if rhat_max > 1.4:
            msg = ("The rhat statistic is larger than 1.4 for some "
                   "parameters. The sampler did not converge.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'error',
                                  extra=rhat)
            warnings.append(warn)
        elif rhat_max > 1.2:
            msg = ("The rhat statistic is larger than 1.2 for some "
                   "parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'warn',
                                  extra=rhat)
            warnings.append(warn)
        elif rhat_max > 1.05:
            msg = ("The rhat statistic is larger than 1.05 for some "
                   "parameters. This indicates slight problems during "
                   "sampling.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'info',
                                  extra=rhat)
            warnings.append(warn)

        eff_min = min(val.min() for val in ess.values())
        sizes = idata.posterior.sizes
        n_samples = sizes['chain'] * sizes['draw']
        if eff_min < 200 and n_samples >= 500:
            msg = ("The estimated number of effective samples is smaller than "
                   "200 for some parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'error',
                                  extra=ess)
            warnings.append(warn)
        elif eff_min / n_samples < 0.1:
            msg = ("The number of effective samples is smaller than "
                   "10% for some parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'warn',
                                  extra=ess)
            warnings.append(warn)
        elif eff_min / n_samples < 0.25:
            msg = ("The number of effective samples is smaller than "
                   "25% for some parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  'info',
                                  extra=ess)
            warnings.append(warn)

        self._add_warnings(warnings)
Example #20
def collect_samples_and_stats(
    config: SimpleNamespace,
    model_cls: Type[BaseModel],
    all_ppl_details: List[PPLDetails],
    train_data: xr.Dataset,
    test_data: xr.Dataset,
    output_dir: str,
) -> Tuple[xr.Dataset, xr.Dataset]:
    """
    :param config: The benchmark configuration.
    :param model_cls: The model class.
    :param all_ppl_details: For each PPL, the impl and inference classes etc.
    :param train_data: The training dataset.
    :param test_data: The held-out test dataset.
    :param output_dir: The directory for storing results.
    :returns: Two datasets:
        variable_metrics
            Coordinates: ppl, metric (n_eff, Rhat), others from model
            Data variables: from model
        other_metrics
            Coordinates: ppl, chain, draw, phase (compile, infer)
            Data variables: pll (ppl, chain, draw), timing (ppl, chain, phase)
    """
    all_variable_metrics, all_pll, all_timing, all_names = [], [], [], []
    all_samples, all_overall_neff, all_overall_neff_per_time = [], [], []
    for pplobj in all_ppl_details:
        all_names.append(pplobj.name)
        rand = np.random.RandomState(pplobj.seed)
        LOGGER.info(f"Starting inference on `{pplobj.name}` with seed {pplobj.seed}")
        # first compile the PPL implementation; this involves two steps
        compile_t1 = time.time()
        # compile step 1: instantiate ppl inference object
        infer_obj = pplobj.inference_class(pplobj.impl_class, train_data.attrs)
        # compile step 2: call compile
        infer_obj.compile(seed=rand.randint(1, int(1e7)), **pplobj.compile_args)
        compile_time = time.time() - compile_t1
        LOGGER.info(f"compiling on `{pplobj.name}` took {compile_time:.2f} secs")
        # then run inference for each trial
        trial_samples, trial_pll, trial_timing = [], [], []
        for trialnum in range(config.trials):
            infer_t1 = time.time()
            samples = infer_obj.infer(
                data=train_data,
                num_samples=config.num_samples,
                seed=rand.randint(1, int(1e7)),
                **pplobj.infer_args,
            )
            infer_time = time.time() - infer_t1
            LOGGER.info(f"inference trial {trialnum} took {infer_time:.2f} secs")
            # compute the pll per sample and then convert it to the actual pll over
            # cumulative samples
            persample_pll = model_cls.evaluate_posterior_predictive(samples, test_data)
            pll = np.logaddexp.accumulate(persample_pll) - np.log(
                np.arange(config.num_samples) + 1
            )
            LOGGER.info(f"PLL = {str(pll)}")
            trial_samples.append(samples)
            trial_pll.append(pll)
            trial_timing.append([compile_time, infer_time])
            # finally, give the inference object an opportunity
            # to write additional diagnostics
            infer_obj.additional_diagnostics(output_dir, f"{pplobj.name}_{trialnum}")
        del infer_obj
        # concatenate the samples data from each trial together so we can compute metrics
        trial_samples_data = xr.concat(
            trial_samples, pd.Index(data=np.arange(config.trials), name="chain")
        )
        neff_data = arviz.ess(trial_samples_data)
        rhat_data = arviz.rhat(trial_samples_data)
        LOGGER.info(f"Trials completed for {pplobj.name}")
        LOGGER.info("== n_eff ===")
        LOGGER.info(str(neff_data.data_vars))
        LOGGER.info("==  Rhat ===")
        LOGGER.info(str(rhat_data.data_vars))

        # compute ess/time
        neff_df = neff_data.to_dataframe()
        overall_neff = [
            neff_df.values.min(),
            np.median(neff_df.values),
            neff_df.values.max(),
        ]
        mean_inference_time = np.mean(np.array(trial_timing)[:, 1])
        overall_neff_per_time = np.array(overall_neff) / mean_inference_time

        LOGGER.info("== overall n_eff [min, median, max]===")
        LOGGER.info(str(overall_neff))
        LOGGER.info("== overall n_eff/s [min, median, max]===")
        LOGGER.info(str(overall_neff_per_time))

        trial_variable_metrics_data = xr.concat(
            [neff_data, rhat_data], pd.Index(data=["n_eff", "Rhat"], name="metric")
        )
        all_variable_metrics.append(trial_variable_metrics_data)
        all_pll.append(trial_pll)
        all_timing.append(trial_timing)
        all_samples.append(trial_samples_data)
        all_overall_neff.append(overall_neff)
        all_overall_neff_per_time.append(overall_neff_per_time)
    # merge the trial-level metrics at the PPL level
    all_variable_metrics_data = xr.concat(
        all_variable_metrics, pd.Index(data=all_names, name="ppl")
    )
    all_other_metrics_data = xr.Dataset(
        {
            "timing": (["ppl", "chain", "phase"], all_timing),
            "pll": (["ppl", "chain", "draw"], all_pll),
            "overall_neff": (["ppl", "percentile"], all_overall_neff),
            "overall_neff_per_time": (["ppl", "percentile"], all_overall_neff_per_time),
        },
        coords={
            "ppl": np.array(all_names),
            "chain": np.arange(config.trials),
            "phase": np.array(["compile", "infer"]),
            "draw": np.arange(config.num_samples),
            "percentile": np.array(["min", "median", "max"]),
        },
    )
    all_samples_data = xr.concat(all_samples, pd.Index(data=all_names, name="ppl"))
    model_cls.additional_metrics(output_dir, all_samples_data, train_data, test_data)
    LOGGER.info("all benchmark samples and metrics collected")
    # save the samples data only if requested
    if getattr(config, "save_samples", False):
        save_dataset(output_dir, "samples", all_samples_data)
    # write out the metrics
    save_dataset(output_dir, "diagnostics", all_variable_metrics_data)
    save_dataset(output_dir, "metrics", all_other_metrics_data)
    return all_variable_metrics_data, all_other_metrics_data
Example #21
 def test_Rhat(self):
     with self.model:
         idata = to_inference_data(self.trace[self.burn:])
     rhat = az.rhat(idata)
     for var in rhat:
         npt.assert_allclose(rhat[var], 1, rtol=0.01)
Example #22
    with model:
        check = pm.sample_prior_predictive(samples=3)

    plt.plot(x.flatten(), y, label="data", color='red')
    for i in range(check['pr'].shape[0]):
        plt.plot(x.flatten(), check['pr'][i], alpha=0.3)
    plt.legend()
    plt.savefig("experiments/plots/gp_time_prior.png", format='png')
    plt.show()

    with model:
        y_ = pm.Normal("y", mu=pr, sigma=sigma, observed=y)

    with model:
        mp = pm.find_MAP(maxeval=300)
        trace = pm.sample(
            200,
            n_init=200,
            tune=100,
            chains=2,
            cores=2,
            return_inferencedata=True,
        )
        arviz.to_netcdf(trace, 'experiments/results/gp_time_trace')

    n_nonconverged = int(
        np.sum(arviz.rhat(trace)[["sigma", "pr_rotated_"]].to_array() > 1.03).values)
    print("%i variables MCMC chains "
          "appear not to have converged." % n_nonconverged)
Example #23
t0 = []
s0 = []
for sample in samples:
    t0.append(np.vstack(sample.apply(lambda ent: ent["t0"].values)).T)
    s0.append(np.vstack(sample.apply(lambda ent: ent["s0"].values)).T)

t0 = np.array(t0)
s0 = np.array(s0)
# chain, draw, x_dim_0
t0 = az.convert_to_dataset(t0)
s0 = az.convert_to_dataset(s0)

trials = range(100, 2500, 100)
res_t0 = []
res_s0 = []
for i in trials:
    result = az.rhat(t0.sel(draw=slice(i, i * 2), x_dim_0=slice(None, 10)))
    res_t0.append(np.array(result.x))

    result = az.rhat(s0.sel(draw=slice(i, i * 2), x_dim_0=slice(None, 10)))
    res_s0.append(np.array(result.x))

res_t0 = np.array(res_t0)
res_s0 = np.array(res_s0)

with PdfPages(args.opt) as pp:
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(np.array(trials) * 2, res_t0, label=[str(i) for i in range(11)])
    ax.set_xlabel("step")
    ax.set_ylabel("\hat{R}")
    ax.set_title("Convergence of t0")
with open(f"posterior_summary_{env_name}.json") as f:
    res_posterior_summary = json.load(f)

res_posterior_summary = pd.DataFrame.from_records(res_posterior_summary,
                                                  index="variable")
res_posterior_summary.index.name = None
print(res_posterior_summary)

reference = (pd.read_csv(f"./reference_posterior_{env_name}.csv",
                         index_col=0,
                         float_precision="high").reset_index().astype(float))

# test arviz functions
funcs = {
    "rhat_rank": lambda x: az.rhat(x, method="rank"),
    "rhat_raw": lambda x: az.rhat(x, method="identity"),
    "ess_bulk": lambda x: az.ess(x, method="bulk"),
    "ess_tail": lambda x: az.ess(x, method="tail"),
    "ess_mean": lambda x: az.ess(x, method="mean"),
    "ess_sd": lambda x: az.ess(x, method="sd"),
    "ess_median": lambda x: az.ess(x, method="median"),
    "ess_raw": lambda x: az.ess(x, method="identity"),
    "ess_quantile01": lambda x: az.ess(x, method="quantile", prob=0.01),
    "ess_quantile10": lambda x: az.ess(x, method="quantile", prob=0.1),
    "ess_quantile30": lambda x: az.ess(x, method="quantile", prob=0.3),
    "mcse_mean": lambda x: az.mcse(x, method="mean"),
    "mcse_sd": lambda x: az.mcse(x, method="sd"),
    "mcse_median": lambda x: az.mcse(x, method="quantile", prob=0.5),
    "mcse_quantile01": lambda x: az.mcse(x, method="quantile", prob=0.01),
    "mcse_quantile10": lambda x: az.mcse(x, method="quantile", prob=0.1),
Example #25
 def test_Rhat(self):
     rhat = az.rhat(self.trace[self.burn:])
     for var in rhat:
         npt.assert_allclose(rhat[var], 1, rtol=0.01)
Example #26
def uni_arviz_rhat(x, var_names=None, method='folded', vars=None):
    return [
        az.rhat(x.transpose()[i].transpose(),
                var_names=var_names,
                method=method) for i in vars or range(x.shape[2])
    ]
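A usage sketch, assuming `x` stacks draws as (chain, draw, variable) so that each transposed slice lands in ArviZ's (chain, draw) convention:

import numpy as np
import arviz as az

x = np.random.randn(4, 1000, 3)  # (chain, draw, variable)
print(uni_arviz_rhat(x))         # one folded rhat per variable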