Exemple #1
0
def _plot_ess(idata, plot_kws=None):
    """Draw ``az.plot_ess`` for *idata* and tidy the resulting grid of axes.

    Every subplot is post-processed:

    * a two-line title (variable name, newline, selection) is folded into
      ``name[selection]``; single-line titles are left unchanged;
    * tick labels are set to size 8;
    * y labels are kept (at size 12) only in the first column and x labels
      only in the last row; all others are blanked.

    Subplot spacing is then adjusted with ``plt.subplots_adjust``.

    Parameters
    ----------
    idata
        Anything ``az.plot_ess`` accepts.
    plot_kws : dict, optional
        Keyword arguments forwarded to ``az.plot_ess``. They override the
        default ``figsize=(16, 8)``. The caller's dict is not modified.

    Returns
    -------
    numpy.ndarray
        The axes as a 2-D array. If ``az.plot_ess`` returned a single Axes
        or a 1-D array, the result has a single row.
    """
    import numpy as np  # local import keeps this helper self-contained

    plot_kws = {"figsize": (16, 8), **(plot_kws or {})}

    # plot_ess may return a single Axes, a 1-D or a 2-D array depending on the
    # number of variables and the ArviZ version; normalise to a 2-D array so
    # the row/column logic below always applies.
    axes = np.asarray(az.plot_ess(idata, **plot_kws), dtype=object)
    if axes.ndim < 2:
        axes = axes.reshape(1, -1)
    n_rows = axes.shape[0]

    for (row, col), ax in np.ndenumerate(axes):
        title = ax.get_title()
        # Only a multi-line title carries a selection; bracketing a plain
        # variable name would leave an unbalanced "name]".
        if "\n" in title:
            title = title.replace("\n", "[") + "]"
        ax.set_title(title)
        ax.xaxis.set_tick_params(labelsize=8)
        ax.yaxis.set_tick_params(labelsize=8)

        # y labels only on the left column.
        if col > 0:
            ax.set_ylabel("")
        else:
            ax.set_ylabel(ax.get_ylabel(), size=12)

        # x labels only on the bottom row.
        if row + 1 < n_rows:
            ax.set_xlabel("")
        else:
            ax.set_xlabel(ax.get_xlabel(), size=12)

    plt.subplots_adjust(left=0.08,
                        right=0.97,
                        top=0.97,
                        bottom=0.1,
                        wspace=0.15,
                        hspace=0.15)
    return axes
Exemple #2
0
    def ess_plot(self, chain_id, num_points=10):
        """Plot the ESS evolution of ``chain_id``, keeping only the first curve.

        Calls ``az.plot_ess(self.chains, var_names=chain_id, kind='evolution',
        min_ess=0)``. On every resulting subplot it then:

        * removes the second line;
        * removes the legend, if there is one;
        * sets the y-range to ``[0, nanmax(first line) + 10]``.

        Parameters
        ----------
        chain_id
            Forwarded to ``az.plot_ess`` as ``var_names``.
        num_points : int, optional
            Not used. Kept so existing callers keep working.

        Returns
        -------
        The object returned by ``az.plot_ess``; its axes are modified in place.
        """
        plts = az.plot_ess(self.chains, var_names=chain_id, kind='evolution', min_ess=0)
        # plot_ess may return one Axes or a 1-D/2-D array of them. A flat view
        # handles every shape; indexing a 2-D array by position would yield
        # whole rows instead of Axes.
        for ax in np.ravel(plts):
            # NOTE(review): lines[0] is assumed to be the bulk-ESS curve and
            # lines[1] the tail-ESS curve; confirm for the ArviZ version in use.
            ax.lines[1].remove()
            legend = ax.get_legend()
            if legend is not None:  # get_legend() returns None when absent
                legend.remove()
            y_max = np.nanmax(ax.lines[0].get_ydata())
            ax.set_ylim(0, y_max + 10.)

        return plts
def _save_and_close(ax, path):
    """Save the figure that owns *ax* to *path*, then close that figure."""
    # ArviZ plot functions return a single Axes or an array of Axes;
    # np.ravel handles both shapes.
    fig = np.ravel(ax)[0].figure
    fig.savefig(path)
    plt.close(fig)


def mcmc_diagnostic_plots(posterior, sample_stats, it, out_dir="./results"):
    """Save trace, posterior, autocorrelation and ESS-evolution plots as PNGs.

    Parameters
    ----------
    posterior, sample_stats : dict
        Forwarded to ``az.from_dict``.
    it
        Iteration tag embedded in every file name (``..._it{it}.png``).
    out_dir : str, optional
        Directory the PNG files are written to. It is not created here.
        Defaults to ``"./results"``.

    Every figure is closed after it is saved, so repeated calls do not
    accumulate open matplotlib figures.
    """
    az_trace = az.from_dict(posterior=posterior, sample_stats=sample_stats)

    # Pair plots are currently disabled. Previous implementation, which needs
    # 2 parameters or more:
    # if len(az_trace.posterior.data_vars) > 1:
    #     ax = az.plot_pair(az_trace, kind="hexbin", gridsize=30, marginals=True)
    #     fig = ax.ravel()[0].figure
    #     plt.ylim((5000, 30000))
    #     plt.xlim((1e-10, 1e-7))
    #     fig.savefig(f"./results/pair_plot_it{it}.png")
    #     plt.clf()
    #
    #     ax = az.plot_pair(
    #         az_trace,
    #         kind=["scatter", "kde"],
    #         kde_kwargs={"fill_last": False},
    #         point_estimate="mean",
    #         marginals=True,
    #     )
    #     fig = ax.ravel()[0].figure
    #     fig.savefig(f"./results/point_estimate_plot_it{it}.png")
    #     plt.clf()

    _save_and_close(az.plot_trace(az_trace, divergences=False),
                    f"{out_dir}/trace_plot_it{it}.png")

    _save_and_close(az.plot_posterior(az_trace),
                    f"{out_dir}/posterior_plot_it{it}.png")

    # Cap the autocorrelation lag at the number of draws (and at 100). The
    # draw count is read from the InferenceData; len() of the raw array would
    # count chains, not draws, for (chain, draw, ...) inputs.
    n_draws = az_trace.posterior.sizes["draw"]
    _save_and_close(az.plot_autocorr(az_trace, max_lag=min(n_draws, 100)),
                    f"{out_dir}/autocorr_plot_it{it}.png")

    _save_and_close(az.plot_ess(az_trace, kind="evolution"),
                    f"{out_dir}/ess_evolution_plot_it{it}.png")
Exemple #4
0
"""
ESS Local Plot
==============

_thumb: .7, .5
"""
import arviz as az

az.style.use("arviz-darkgrid")

# Local ESS of ``mu``; each estimate is drawn with a large horizontal-line marker.
marker_kwargs = dict(marker="_", ms=20, mew=2)
idata = az.load_arviz_data("centered_eight")
az.plot_ess(idata, var_names=["mu"], kind="local", **marker_kwargs)
Exemple #5
0
"""
ESS Quantile Plot
=================

_thumb: .4, .5
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

idata = az.load_arviz_data("radon")

# Quantile ESS of ``sigma_y``, drawn with colour-cycle entry C4.
quantile_kwargs = {"var_names": ["sigma_y"], "kind": "quantile", "color": "C4"}
az.plot_ess(idata, **quantile_kwargs)

plt.show()
Exemple #6
0
def run_model(month=7,
              n_samples=1000,
              interp_type='ncs',
              binary=True,
              spike=0.9,
              hdi_prob=0.95,
              zero_inf=0.7):
    """Fit the mandate/age/sex zero-inflated Poisson model for one month.

    The function:

    1. Reads ``../data/<interp_type>-pop-deaths-and-<binary|nonbinary>-mandates.csv``
       and keeps the rows whose ``Month`` equals ``month``.
    2. Builds a PyMC3 model with a spike-and-slab prior on the mandate
       coefficients and hierarchical priors on age, sex and intercept.
    3. Samples the model.
    4. Writes summary CSVs, ridge plots, ESS plots and a posterior predictive
       plot under ``../images/``.

    Parameters
    ----------
    month : int
        Value of the ``Month`` column to fit.
    n_samples : int
        Number of draws; also used as the number of tuning steps.
    interp_type : str
        Prefix of the input CSV; also part of every output file name.
    binary : bool
        Selects the ``binary`` or ``nonbinary`` mandates file.
    spike : float
        Bernoulli probability of the mandate indicators ``xi``.
    hdi_prob : float
        HDI probability used for the ``az.summary`` CSV files.
    zero_inf : float
        ``psi`` argument of ``pm.ZeroInflatedPoisson``.

    Returns
    -------
    The trace returned by ``pm.sample``.
    """
    # preprocessing
    binary_str = 'binary' if binary else 'nonbinary'
    suffix = interp_type + '_' + binary_str  # shared tail of output file names
    df = pd.read_csv('../data/' + interp_type + '-pop-deaths-and-' +
                     binary_str + '-mandates.csv',
                     index_col=0)
    df = df.rename(columns={
        "Age Group": "Age_Group",
        "COVID-19 Deaths": "covid_19_deaths"
    })
    test_df = df[df["Month"] == month]
    # The last 4 columns are the mandate columns. Kept as a DataFrame so the
    # column names can label the summaries and plots below.
    mandates = test_df.iloc[:, -4:]
    covid_deaths = test_df["covid_19_deaths"]
    population = test_df["Population"] / 1000000  # in millions

    age_data = pd.get_dummies(test_df["Age_Group"]).drop("Under 1 year",
                                                         axis=1)
    sex_data = pd.get_dummies(test_df["Sex"], drop_first=True)

    n_mandates = len(mandates.columns)
    # Derived from the data instead of hard-coding 10, so the age priors stay
    # consistent with age_data when the set of age groups changes.
    n_age = len(age_data.columns)

    # run the model
    with pm.Model() as model:

        # spike and slab prior
        tau = pm.InverseGamma('tau', alpha=20, beta=20)
        xi = pm.Bernoulli('xi', p=spike, shape=n_mandates)
        beta_mandates = pm.MvNormal('beta_mandate',
                                    mu=0,
                                    cov=tau * np.eye(n_mandates),
                                    shape=n_mandates)

        # age prior
        mu_age_mean = np.linspace(-5, 5, n_age)
        cov = pm.HalfNormal('cov', sigma=2)
        mu_age = pm.MvNormal('mu_age',
                             mu=mu_age_mean,
                             cov=np.identity(n_age),
                             shape=(1, n_age))
        beta_age = pm.MvNormal('beta_age',
                               mu=mu_age,
                               cov=(cov**2) * np.identity(n_age),
                               shape=(1, n_age))

        # sex prior
        # NOTE: 'simga_' is the variable name recorded in the trace; renaming
        # it would change the trace's variable names.
        mu_sex = pm.Normal('mu_sex', mu=0, sigma=1)
        sigma_sex = pm.HalfNormal('simga_sex', sigma=2)
        beta_sex = pm.Normal('beta_sex', mu=mu_sex, sigma=sigma_sex)

        # intercept prior
        mu_intercept = pm.Normal('mu_intercept', mu=0, sigma=1)
        sigma_intercept = pm.HalfNormal('simga_intercept', sigma=2)
        beta_intercept = pm.Normal('beta_intercept',
                                   mu=mu_intercept,
                                   sigma=sigma_intercept)

        # mean setup for likelihood: cast inputs to Theano's float type
        float_x = theano.config.floatX
        mandate_matrix = np.array(mandates).astype(float_x)
        population_arr = np.array(population).astype(float_x)
        sex_matrix = np.array(sex_data).astype(float_x)
        age_matrix = np.array(age_data).astype(float_x)
        w_mandates = theano.shared(mandate_matrix, 'w_mandate')
        w_sex = theano.shared(sex_matrix, 'w_sex')
        w_age = theano.shared(age_matrix, 'w_age')
        # xi * beta_mandates: mandate coefficients gated by their Bernoulli indicators
        mean = (beta_intercept
                + pm.math.matrix_dot(w_mandates, xi * beta_mandates)
                + pm.math.matrix_dot(w_sex, beta_sex).T
                + pm.math.matrix_dot(w_age, beta_age.T).T)

        # likelihood
        pm.ZeroInflatedPoisson('y_obs',
                               psi=zero_inf,
                               theta=population_arr * tt.exp(mean),
                               observed=covid_deaths)

        # sample from posterior
        trace = pm.sample(n_samples,
                          tune=n_samples,
                          nuts={'target_accept': 0.98})

    # posterior hdis: one summary CSV per coefficient group
    summary_specs = [
        ("beta_mandate", mandates.columns, 'mandate_'),
        ("beta_sex", sex_data.columns, 'sex_'),
        ("beta_age", age_data.columns, 'age_'),
        ("beta_intercept", None, 'intercept_'),
    ]
    for var_name, row_labels, prefix in summary_specs:
        x = az.summary(trace, var_names=[var_name], hdi_prob=hdi_prob)
        if row_labels is not None:
            x.index = row_labels
        x.to_csv('../images/posteriors/' + prefix + suffix + '_' +
                 'summary.csv')

    # posterior distributions (ridge plots)
    forest_specs = [
        ("beta_intercept", None,
         r'Posterior Distribution of $\beta_0$', 'intercept'),
        ("beta_age", age_data.columns,
         r'Posterior Distribution of $\beta_{age}$', 'age'),
        ("beta_sex", sex_data.columns,
         r'Posterior Distribution of $\beta_{sex}$', 'sex'),
        ("beta_mandate", mandates.columns,
         r'Posterior Distribution of $\beta_{mandate}$', 'mandate'),
    ]
    for var_name, tick_labels, title, prefix in forest_specs:
        ax = az.plot_forest(trace,
                            'ridgeplot',
                            var_names=[var_name],
                            combined=True,
                            hdi_prob=0.99999)
        if tick_labels is not None:
            # tick labels are applied in reverse column order
            ax[0].set_yticklabels(reversed(tick_labels))
        ax[0].set_title(title)
        plt.savefig('../images/posteriors/' + prefix + '_posteriors_' +
                    suffix + '.png')

    # ESS Plots. az.plot_ess may return a single Axes or a 1-D/2-D array of
    # them; np.ravel gives the subplots in plot order in every case.
    ax = az.plot_ess(trace, var_names=["beta_intercept"])
    np.ravel(ax)[0].set_title(r'$\beta_0$  ESS')
    plt.savefig('../images/ess/' + suffix + '_interceptESS.png')

    # NOTE(review): hard-coded in the (presumably sorted) column order of
    # age_data; keep in sync with the age groups in the data.
    age_titles = [
        r'$\beta_{age[1-4]}$  ESS',
        r'$\beta_{age[15-24]}$  ESS',
        r'$\beta_{age[25-34]}$  ESS',
        r'$\beta_{age[35-44]}$  ESS',
        r'$\beta_{age[45-54]}$  ESS',
        r'$\beta_{age[5-14]}$  ESS',
        r'$\beta_{age[55-64]}$  ESS',
        r'$\beta_{age[65-74]}$  ESS',
        r'$\beta_{age[75-84]}$  ESS',
        r'$\beta_{age[85+]}$  ESS',
    ]
    ax = az.plot_ess(trace, var_names=["beta_age"])
    for axis, title in zip(np.ravel(ax), age_titles):
        axis.set_title(title, fontsize=18)
    plt.savefig('../images/ess/' + suffix + '_ageESS.png')

    ax = az.plot_ess(trace, var_names=["beta_sex"])
    np.ravel(ax)[0].set_title(r'$\beta_{sex}$  ESS')
    plt.savefig('../images/ess/' + suffix + '_sexESS.png')

    mandate_titles = [
        r'$\beta_{mandate[April]}$  ESS',
        r'$\beta_{mandate[May]}$  ESS',
        r'$\beta_{mandate[June]}$  ESS',
        r'$\beta_{mandate[July]}$  ESS',
    ]
    ax = az.plot_ess(trace, var_names=["beta_mandate"])
    for axis, title in zip(np.ravel(ax), mandate_titles):
        axis.set_title(title, fontsize=18)
    plt.savefig('../images/ess/' + suffix + '_mandateESS.png')

    # posterior predictive checking
    with model:
        ppc = pm.sample_posterior_predictive(trace, var_names=["y_obs"])
    az.plot_ppc(az.from_pymc3(posterior_predictive=ppc, model=model))
    plt.savefig('../images/posterior_predictive/' + suffix + '.png')

    # return trace so that user can work with posterior data directly
    return trace
Exemple #7
0
"""
ESS Quantile Plot
=================

_thumb: .4, .5
"""
import arviz as az

idata = az.load_arviz_data("radon")

# Quantile ESS of ``sigma`` in red, rendered with the Bokeh backend.
ess_options = dict(var_names=["sigma"], kind="quantile", color="red", backend="bokeh")
ax = az.plot_ess(idata, **ess_options)
Exemple #8
0
 def plot_ess(self, **kwargs):
     """Plot the effective sample size (ESS) of ``self.data``.

     Thin wrapper around ``az.plot_ess``. Every keyword argument (for example
     ``kind``) is forwarded unchanged, and whatever ``az.plot_ess`` returns
     is returned.
     """
     source = self.data
     return az.plot_ess(source, **kwargs)
# In fact, we will just use the harmonic oscillator ansatz.
# We also compute the effective sample size using az.ess() from the ArviZ package.

for sigma in numpy.linspace(0.01, 3, 30):
    def normal_proposal(old_point):
        # Symmetric Gaussian random-walk proposal centred on the current point.
        # ``sigma`` is read from the enclosing loop; this is safe because the
        # proposal is only used within the current iteration.
        return Normal(old_point, sigma*torch.ones_like(old_point)).sample()
    tf = HarmonicTrialFunction(torch.ones(1))
    n_walkers = 2
    init_config = torch.ones(n_walkers, 1)
    results = metropolis_symmetric(tf, init_config, normal_proposal, num_walkers=n_walkers, num_steps=100000)
    # Convert the samples once; az.convert_to_dataset returns exactly the
    # posterior group of az.convert_to_inference_data.
    samples = results.numpy()
    dataset2 = az.convert_to_inference_data(samples)
    dataset1 = dataset2.posterior

    # NOTE: the file names do not include sigma, so each iteration overwrites
    # the previous iteration's plots.
    for ess_kind, file_name in (("local", "Local"),
                                ("quantile", "quantile"),
                                ("evolution", "Evolution")):
        az.plot_ess(dataset2, kind=ess_kind)
        plt.savefig(file_name)
        # Each plot_ess call opens a new figure; close it so 30 sigmas do not
        # leave 90 figures open.
        plt.close()
    print(az.ess(dataset1).data_vars)
# In the Output_of_run array we are using units of 1000.    
#Output_of_run = numpy.array([0.02366, 1.087, 3.579, 7.21, 11.32, 15.9, 20.19, 25.2, 29.98, 32.94, 36.67, 39.41, 38.68, 42.96, 44.4, 45.35, 44.83, 45.94, 43.73, 46.34, 44.69, 45.15, 41.88,41.41, 41.33, 41, 38.46, 38.3, 37.49, 36.02]) 
#y_data = Output_of_run
#x_data = numpy.linspace(0.01, 3, 30)
#plt.scatter(x_data, y_data, c='r', label='ess scatter')
##plt.plot(x_data, y_data, label='ess fit')
#plt.xlabel('x')
#plt.ylabel('y')
#plt.title('ess vs sigma ')
"""
ESS Evolution Plot
==================

_thumb: .2, .8
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

idata = az.load_arviz_data("radon")

# ESS evolution plot for ``b``.
evolution_kwargs = {"var_names": ["b"], "kind": "evolution"}
az.plot_ess(idata, **evolution_kwargs)

plt.show()
# Load the three comma-separated std_stress_data_* arrays.
std_stress_data_20 = np.loadtxt('std_stress_data_20.csv', delimiter=',')
std_stress_data_40 = np.loadtxt('std_stress_data_40.csv', delimiter=',')
std_stress_data_RS = np.loadtxt('std_stress_data_RS.csv', delimiter=',')

data = az.from_netcdf('save_arviz_data_stanwound')

az.style.use("default")

# Parameters inspected by every diagnostic below.
params = ['kv', 'k0', 'kf', 'k2', 'b', 'mu', 'phif']

az.rhat(data, var_names=params)

extra_kwargs = {"color": "lightsteelblue"}

# Shared styling for the local and quantile ESS plots.
ess_style = dict(figsize=(18, 18),
                 color="royalblue",
                 extra_kwargs=extra_kwargs,
                 textsize=20)
for ess_kind in ("local", "quantile"):
    az.plot_ess(data, kind=ess_kind, var_names=params, **ess_style)
az.plot_ess(data,
            kind="evolution",
            var_names=['kv', 'k0', 'kf', 'k2', 'b', 'mu', 'phif'],
Exemple #12
0
def param_validate_arviz(inferred, variables):
    """Draw ArviZ MCSE, ESS and trace plots for ``variables`` of ``inferred``.

    The plots are drawn for their side effect only; nothing is returned.
    """
    for draw in (az.plot_mcse, az.plot_ess, az.plot_trace):
        draw(inferred, var_names=variables)
"""
ESS Local Plot
==============

_thumb: .6, .5
"""
import arviz as az

idata = az.load_arviz_data("non_centered_eight")

# Local ESS of ``mu`` with a rug, rendered with the Bokeh backend.
local_opts = {"var_names": ["mu"], "kind": "local", "rug": True, "backend": "bokeh"}
ax = az.plot_ess(idata, **local_opts)
Exemple #14
0
"""
ESS Evolution Plot
==================

_thumb: .2, .8
"""
import arviz as az

idata = az.load_arviz_data("radon")

# ESS evolution of ``b``, rendered with the Bokeh backend.
evolution_opts = dict(var_names=["b"], kind="evolution", backend="bokeh")
ax = az.plot_ess(idata, **evolution_opts)