示例#1
0
def j_summary(samples, ctype='hist', properties={'width': 800}):
    """Print a summary of ``samples``, show their correlations and chart each column.

    The summary is printed with ``print_summary(..., 0.89, False)``, the
    correlation matrix is shown with ``display``, and one chart per column is
    built with Altair and displayed. Nothing is returned.

    :param samples: either a ``dict`` mapping names to sample values, or a
        DataFrame-like object exposing ``columns``, ``T.values``, ``corr()``
        and ``sample()``.
    :param str ctype: ``'hist'`` for binned bar charts (at most 20 bins) or
        ``'density'`` for density line charts.
    :param dict properties: currently unused; kept so existing callers keep
        working. It is never mutated.
    :raises ValueError: if ``ctype`` is neither ``'hist'`` nor ``'density'``.
    """
    # Fail fast: otherwise an unknown chart type would only surface as an
    # UnboundLocalError after the summary and correlations were already shown.
    if ctype not in ('density', 'hist'):
        raise ValueError(
            f"Unsupported ctype {ctype!r}; expected 'hist' or 'density'.")

    if isinstance(samples, dict):
        print_summary(samples, 0.89, False)
        df = pd.DataFrame(samples).clean_names()
    else:
        # Rebuild a name -> values mapping from the columns for print_summary.
        print_summary(dict(zip(samples.columns, samples.T.values)), 0.89,
                      False)
        df = samples
    display(df.corr())

    # Frames with 5000 rows or more are down-sampled to 4000 rows for charting.
    df = df if len(df) < 5000 else df.sample(n=4000)
    base = alt.Chart(df).properties(height=30)

    if ctype == 'density':
        charts = [
            base.mark_line().transform_density(
                column,
                as_=[column, 'density'],
            ).encode(alt.X(f'{column}:Q'), alt.Y('density:Q'))
            for column in df.columns
        ]
        return_chart = alt.vconcat(*charts)
    else:  # ctype == 'hist'
        return_chart = base.mark_bar().encode(
            alt.X(bin=alt.Bin(maxbins=20),
                  field=alt.repeat("row"),
                  type='quantitative'),
            y=alt.Y(title=None, aggregate='count',
                    type='quantitative')).repeat(row=list(df.columns))

    display(return_chart)
示例#2
0
    def print_summary(self, prob=0.9, exclude_deterministic=True):
        """
        Print summary statistics of the samples stored on this instance.

        :param float prob: probability mass forwarded to the module-level
            ``print_summary`` helper.
        :param bool exclude_deterministic: when true and the stored samples are
            a dict, keep only the sites that also appear in the sample field of
            the last state (provided that field is a dict as well).
        """
        collected = self._states[self._sample_field]
        if exclude_deterministic and isinstance(collected, dict):
            last_field = attrgetter(self._sample_field)(self._last_state)
            # XXX: the last state's field may not be a dictionary even though
            # its postprocessed value `collected` is one; nothing is filtered
            # in that case.
            # TODO: when both are dictionaries their key names can differ for
            # reasons other than deterministic sites. This logic might need
            # revisiting in the future.
            if isinstance(last_field, dict):
                kept = {}
                for name, value in collected.items():
                    if name in last_field:
                        kept[name] = value
                collected = kept
        print_summary(collected, prob=prob)

        extras = self.get_extra_fields()
        if "diverging" not in extras:
            return
        divergences = jnp.sum(extras["diverging"])
        print("Number of divergences: {}".format(divergences))
示例#3
0
 def print_summary(self, prob=0.9, exclude_deterministic=True):
     """
     Print summary statistics of the samples stored on this instance.

     :param float prob: probability mass forwarded to the module-level
         ``print_summary`` helper.
     :param bool exclude_deterministic: when true and the stored samples are a
         dict, keep only the sites whose names appear in ``self._last_state.z``.
     """
     collected = self._states[self._sample_field]
     if exclude_deterministic and isinstance(collected, dict):
         kept = {}
         for name, value in collected.items():
             if name in self._last_state.z:
                 kept[name] = value
         collected = kept
     print_summary(collected, prob=prob)

     extras = self.get_extra_fields()
     if 'diverging' not in extras:
         return
     divergences = jnp.sum(extras['diverging'])
     print("Number of divergences: {}".format(divergences))
示例#4
0
def analyze_post(post, method):
    """Summarize posterior samples and save three diagnostic figures.

    Prints a summary via ``print_summary(post, 0.95, False)`` and then draws,
    titles, saves and shows a forest plot, a pair plot of ``br`` against
    ``bl``, and a KDE of ``bl + br``.

    :param post: mapping of posterior samples; must contain the keys ``"bl"``
        and ``"br"``.
    :param method: label used as the title of every figure and embedded in
        the saved file names.
    """
    print_summary(post, 0.95, False)

    _, ax = plt.subplots()
    az.plot_forest(post, hdi_prob=0.95, figsize=(10, 4), ax=ax)
    plt.title(method)
    pml.savefig(f'multicollinear_forest_plot_{method}.pdf')
    plt.show()

    _, ax = plt.subplots()
    az.plot_pair(post, var_names=["br", "bl"],
                 scatter_kwargs={"alpha": 0.1}, ax=ax)
    # Set the title before saving so the saved PDF carries it, as the other
    # two figures do.
    plt.title(method)
    pml.savefig(f'multicollinear_joint_post_{method}.pdf')
    plt.show()

    sum_blbr = post["bl"] + post["br"]
    _, ax = plt.subplots()
    az.plot_kde(sum_blbr, label="sum of bl and br", ax=ax)
    plt.title(method)
    pml.savefig(f'multicollinear_sum_post_{method}.pdf')
    plt.show()
示例#5
0
文件: neutra.py 项目: afcarl/numpyro
def main(args):
    """Run the dual-moon comparison and save the result to ``neutra.pdf``.

    In order: sample ``dual_moon_model`` with NUTS, fit an ``AutoBNAFNormal``
    guide with SVI, sample the NeuTra-reparameterized model with NUTS, and
    draw a 2x3 grid of plots comparing the three results.

    :param args: object providing ``num_warmup``, ``num_samples``,
        ``num_chains``, ``hidden_factor`` and ``num_iters``.
    """

    def make_mcmc(kernel):
        # The progress bar is disabled whenever NUMPYRO_SPHINXBUILD is set.
        return MCMC(
            kernel,
            args.num_warmup,
            args.num_samples,
            num_chains=args.num_chains,
            progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)

    print("Start vanilla HMC...")
    sampler = make_mcmc(NUTS(dual_moon_model))
    sampler.run(random.PRNGKey(0))
    sampler.print_summary()
    vanilla_samples = sampler.get_samples()['x'].copy()

    hidden = [args.hidden_factor, args.hidden_factor]
    guide = AutoBNAFNormal(dual_moon_model, hidden_factors=hidden)
    svi = SVI(dual_moon_model, guide, optim.Adam(0.003), ELBO())
    initial_state = svi.init(random.PRNGKey(1))

    def svi_step(state, _):
        # The scanned-over value is ignored; only the carried state matters.
        return svi.update(state)

    print("Start training guide...")
    last_state, losses = lax.scan(svi_step, initial_state,
                                  jnp.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_draws = guide.sample_posterior(
        random.PRNGKey(2), params, sample_shape=(args.num_samples, ))
    guide_samples = guide_draws['x'].copy()

    print("\nStart NeuTra HMC...")
    neutra = NeuTraReparam(guide, params)
    sampler = make_mcmc(NUTS(neutra.reparam(dual_moon_model)))
    sampler.run(random.PRNGKey(3))
    sampler.print_summary()
    zs = sampler.get_samples(group_by_chain=True)["auto_shared_latent"]
    print("Transform samples into unwarped space...")
    samples = neutra.transform_sample(zs)
    print_summary(samples)
    zs = zs.reshape(-1, 2)
    samples = samples['x'].reshape(-1, 2).copy()

    # ---- plots ----

    # Draws from the guide's base distribution, used only for plotting.
    guide_base_samples = dist.Normal(jnp.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    # Evaluate the target density on a 100x100 grid over [-3, 3] x [-3, 3].
    ticks = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(ticks, ticks)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 8), constrained_layout=True)
    gs = GridSpec(2, 3, figure=fig)
    # Axes are created column by column: ax1/ax2 share the first column, etc.
    ax1, ax2, ax3, ax4, ax5, ax6 = [
        fig.add_subplot(gs[row, col]) for col in range(3) for row in range(2)
    ]

    # Axis settings shared by every panel drawn over the [-3, 3] square.
    moon_axes = dict(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1')

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(**moon_axes, title='Posterior using\nAutoBNAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        **moon_axes,
        title='AutoBNAFNormal base samples\n(True=left moon; False=right moon)'
    )

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    # Trace the last 50 draws.
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(**moon_axes, title='Posterior using\nvanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the\nwarped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(**moon_axes, title='Posterior using\nNeuTra HMC sampler')

    plt.savefig("neutra.pdf")
示例#6
0
def main(args):
    """Run the dual-moon comparison and save the result to ``neutra.pdf``.

    In order: sample ``dual_moon_model`` with NUTS, fit an ``AutoIAFNormal``
    guide with SVI, run NUTS on the potential energy transformed by the
    guide's transform, and draw a 3x2 grid of plots comparing the results.

    :param args: object providing ``num_warmup``, ``num_samples``,
        ``num_hidden`` and ``num_iters``.
    """
    print("Start vanilla HMC...")
    sampler = MCMC(NUTS(dual_moon_model), args.num_warmup, args.num_samples)
    sampler.run(random.PRNGKey(0))
    sampler.print_summary()
    vanilla_samples = sampler.get_samples()['x'].copy()

    adam = optim.Adam(0.01)
    # TODO: it is hard to find good hyperparameters such that IAF guide can learn this model.
    # We will use BNAF instead!
    hidden = [args.num_hidden, args.num_hidden]
    guide = AutoIAFNormal(dual_moon_model, num_flows=2, hidden_dims=hidden)
    svi = SVI(dual_moon_model, guide, adam, AutoContinuousELBO())
    initial_state = svi.init(random.PRNGKey(1))

    def svi_step(state, _):
        # The scanned-over value is ignored; only the carried state matters.
        return svi.update(state)

    print("Start training guide...")
    last_state, losses = lax.scan(svi_step, initial_state,
                                  np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    guide_draws = guide.sample_posterior(
        random.PRNGKey(0), params, sample_shape=(args.num_samples, ))
    guide_samples = guide_draws['x'].copy()

    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2),
                                                     dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy,
                                       potential_fn, transform)

    def transformed_constrain_fn(x):
        # Apply the guide's transform first, then the model's constrain_fn.
        return constrain_fn(transform(x))

    print("\nStart NeuTra HMC...")
    sampler = MCMC(NUTS(potential_fn=transformed_potential_fn),
                   args.num_warmup, args.num_samples)
    init_params = np.zeros(guide.latent_size)
    sampler.run(random.PRNGKey(3), init_params=init_params)
    sampler.print_summary()
    zs = sampler.get_samples()
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)

    def prepend_axis(x):
        # Adds a leading axis of length one to every leaf before summarizing.
        return x[None, ...]

    print_summary(tree_map(prepend_axis, samples))
    samples = samples['x'].copy()

    # ---- plots ----

    # Draws from the guide's base distribution, used only for plotting.
    guide_base_samples = dist.Normal(np.zeros(2),
                                     1.).sample(random.PRNGKey(4), (1000, ))
    guide_trans_samples = vmap(transformed_constrain_fn)(
        guide_base_samples)['x']

    # Evaluate the target density on a 100x100 grid over [-3, 3] x [-3, 3].
    ticks = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(ticks, ticks)
    P = np.exp(DualMoonDistribution().log_prob(np.stack([X1, X2], axis=-1)))

    fig = plt.figure(figsize=(12, 16), constrained_layout=True)
    gs = GridSpec(3, 2, figure=fig)
    # Axes are created row by row: ax1/ax2 share the first row, etc.
    ax1, ax2, ax3, ax4, ax5, ax6 = [
        fig.add_subplot(gs[row, col]) for row in range(3) for col in range(2)
    ]

    # Axis settings shared by every panel drawn over the [-3, 3] square.
    moon_axes = dict(xlim=[-3, 3], ylim=[-3, 3], xlabel='x0', ylabel='x1')

    ax1.plot(np.log(losses[1000:]))
    ax1.set_title('Autoguide training log loss (after 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples[:, 0], guide_samples[:, 1], n_levels=30, ax=ax2)
    ax2.set(**moon_axes, title='Posterior using AutoIAFNormal guide')

    sns.scatterplot(guide_base_samples[:, 0],
                    guide_base_samples[:, 1],
                    ax=ax3,
                    hue=guide_trans_samples[:, 0] < 0.)
    ax3.set(
        **moon_axes,
        title='AutoIAFNormal base samples (True=left moon; False=right moon)')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0],
                vanilla_samples[:, 1],
                n_levels=30,
                ax=ax4)
    # Trace the last 50 draws.
    ax4.plot(vanilla_samples[-50:, 0],
             vanilla_samples[-50:, 1],
             'bo-',
             alpha=0.5)
    ax4.set(**moon_axes, title='Posterior using vanilla HMC sampler')

    sns.scatterplot(zs[:, 0],
                    zs[:, 1],
                    ax=ax5,
                    hue=samples[:, 0] < 0.,
                    s=30,
                    alpha=0.5,
                    edgecolor="none")
    ax5.set(xlim=[-5, 5],
            ylim=[-5, 5],
            xlabel='x0',
            ylabel='x1',
            title='Samples from the warped posterior - p(z)')

    ax6.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples[:, 0], samples[:, 1], n_levels=30, ax=ax6)
    ax6.plot(samples[-50:, 0], samples[-50:, 1], 'bo-', alpha=0.2)
    ax6.set(**moon_axes, title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()
示例#7
0
 def print_summary(self, prob=0.9):
     """
     Print summary statistics of the samples stored under ``self._states['z']``.

     :param float prob: probability mass forwarded to the module-level
         ``print_summary`` helper.
     """
     print_summary(self._states['z'], prob=prob)

     extras = self.get_extra_fields()
     if 'diverging' not in extras:
         return
     divergences = np.sum(extras['diverging'])
     print("Number of divergences: {}".format(divergences))
# Fit the guide m5_3 to the model with SVI, conditioning on M, A and D.
svi = SVI(model,
          m5_3,
          optim.Adam(1),
          Trace_ELBO(),
          M=d.M.values,
          A=d.A.values,
          D=d.D.values)
p5_3, losses = svi.run(random.PRNGKey(0), 1000)
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (1000, ))

# Posterior

# A tuple rather than a set: set iteration order over strings can differ
# between interpreter runs, which would shuffle the printed summaries.
param_names = ('a', 'bA', 'bM', 'sigma')
for p in param_names:
    print(f'posterior for {p}')
    print_summary(post[p], 0.95, False)

# PPC

# call predictive without specifying new data
# so it uses original data
post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (int(1e4), ))
post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2),
                                         M=d.M.values,
                                         A=d.A.values)
mu = post_pred["mu"]

# summarize samples across cases
mu_mean = jnp.mean(mu, 0)
# 5.5th and 94.5th percentiles along axis 0, i.e. a central 89% interval.
mu_PI = jnp.percentile(mu, q=(5.5, 94.5), axis=0)