示例#1
0
def test_reparam_log_joint(model, kwargs):
    """Compare the reparameterized model's ``pe`` value against the original's.

    Builds an ``AutoIAFNormal`` guide for ``model``, wraps it in
    ``NeuTraReparam`` and asserts that the function returned by
    ``initialize_model`` for the reparameterized model equals the original
    model's function (evaluated at the transformed sample) minus the
    log-abs-det-Jacobian of the guide's composed transforms.

    :param model: model callable under test.
    :param dict kwargs: keyword arguments forwarded to the guide, to
        ``initialize_model`` and to ``guide.get_posterior``.
    """
    flow_guide = AutoIAFNormal(model)
    flow_guide(**kwargs)
    neutra = NeuTraReparam(flow_guide)
    warped_model = neutra.reparam(model)

    # Only the second and third entries of the first result, and the first
    # two entries of the second result, are used below.
    _, base_pe_fn, site_transforms, _ = initialize_model(
        model, model_kwargs=kwargs)
    warped_init, warped_pe_fn, _, _ = initialize_model(
        warped_model, model_kwargs=kwargs)

    # The first (presumably only -- confirm) initial value is the shared latent.
    z = tuple(warped_init.values())[0]
    site_values = neutra.transform_sample(z)
    warped_pe = warped_pe_fn(warped_init)

    flow = ComposeTransform(flow_guide.get_posterior(**kwargs).transforms)
    flow_out = flow(z)
    log_det = flow.log_abs_det_jacobian(z, flow_out)

    # Push every site value through its per-site transform before evaluating
    # the original model's function.
    base_inputs = {
        name: site_transforms[name](value)
        for name, value in site_values.items()
    }
    base_pe = base_pe_fn(base_inputs)
    assert_close(warped_pe, base_pe - log_det)
示例#2
0
def test_neals_funnel_smoke(jit):
    """Smoke-test NeuTra reparameterization of ``neals_funnel`` under NUTS.

    Runs 1000 SVI steps on an ``AutoIAFNormal`` guide, samples the
    reparameterized model with NUTS (50 warmup steps, 50 samples), passes the
    ``'y_shared_latent'`` draws through ``neutra.transform_sample`` and checks
    that the result contains the ``'x'`` and ``'y'`` sites.

    :param jit: forwarded to ``NUTS`` as ``jit_compile``.
    """
    dim = 10

    flow_guide = AutoIAFNormal(neals_funnel)
    optimizer = optim.Adam({"lr": 1e-10})
    svi = SVI(neals_funnel, flow_guide, optimizer, Trace_ELBO())
    for _ in range(1000):
        svi.step(dim)

    frozen_guide = flow_guide.requires_grad_(False)
    neutra = NeuTraReparam(frozen_guide)
    kernel = NUTS(neutra.reparam(neals_funnel), jit_compile=jit)
    sampler = MCMC(kernel, num_samples=50, warmup_steps=50)
    sampler.run(dim)
    draws = sampler.get_samples()

    # XXX: `MCMC.get_samples` prepends its batch dim at the far left of every
    # site rather than uniformly at -max_plate_nesting-1, so a dim is inserted
    # by hand before transforming.
    shared_latent = draws['y_shared_latent'].unsqueeze(-2)
    transformed_samples = neutra.transform_sample(shared_latent)
    for site in ('x', 'y'):
        assert site in transformed_samples
示例#3
0
def test_neals_funnel_smoke():
    """Smoke-test NeuTra reparameterization of a locally defined funnel model.

    Runs 1000 SVI steps on an ``AutoIAFNormal`` guide, samples the
    reparameterized model with NUTS (50 warmup steps, 50 samples), passes the
    ``'y_shared_latent'`` draws through ``neutra.transform_sample`` and checks
    that the result contains the ``'x'`` and ``'y'`` sites.
    """
    dim = 10

    def model():
        # `y` sets the scale of every `x` entry inside the plate of size `dim`.
        y = pyro.sample('y', dist.Normal(0, 3))
        with pyro.plate("D", dim):
            pyro.sample('x', dist.Normal(0, torch.exp(y/2)))

    flow_guide = AutoIAFNormal(model)
    optimizer = optim.Adam({"lr": 1e-10})
    svi = SVI(model, flow_guide, optimizer, Trace_ELBO())
    for _ in range(1000):
        svi.step()

    neutra = NeuTraReparam(flow_guide)
    warped_model = neutra.reparam(model)
    sampler = MCMC(NUTS(warped_model), num_samples=50, warmup_steps=50)
    sampler.run()
    draws = sampler.get_samples()

    # XXX: `MCMC.get_samples` prepends its batch dim at the far left of every
    # site rather than uniformly at -max_plate_nesting-1, so a dim is inserted
    # by hand before transforming.
    shared_latent = draws['y_shared_latent'].unsqueeze(-2)
    transformed_samples = neutra.transform_sample(shared_latent)
    for site in ('x', 'y'):
        assert site in transformed_samples
示例#4
0
文件: neutra.py 项目: zeta1999/pyro
def _parse_limits(spec):
    """Parse a ``'low,high'`` string into a 2-tuple of floats.

    :param str spec: two comma-separated numbers, e.g. ``'-3,3'``.
    :returns: ``(low, high)`` as floats.
    :raises ValueError: if an entry is not a number or there are not exactly
        two entries.
    """
    limits = tuple(float(part) for part in spec.strip().split(','))
    if len(limits) != 2:
        raise ValueError(
            'expected two comma-separated numbers, got {!r}'.format(spec))
    return limits


def _plot_density(ax, grid, xlim, ylim, title, samples=None):
    """Draw the filled density contours on ``ax``, optionally with a sample KDE.

    :param ax: matplotlib axes to draw on.
    :param grid: ``(x1, x2, p)`` triple handed to ``ax.contourf``.
    :param xlim: x-axis limits.
    :param ylim: y-axis limits.
    :param str title: axes title.
    :param samples: optional array indexed as ``samples[:, 0]`` /
        ``samples[:, 1]``; when given, its KDE is overlaid.
    """
    x1, x2, p = grid
    ax.contourf(x1, x2, p, cmap='OrRd')
    ax.set(xlabel='x0', ylabel='x1', xlim=xlim, ylim=ylim, title=title)
    if samples is not None:
        # Keyword x/y: recent seaborn releases no longer accept them
        # positionally.
        sns.kdeplot(x=samples[:, 0], y=samples[:, 1], ax=ax)


def _run_guide_stage(args, guide, label, axes, grid, xlim, ylim):
    """Fit ``guide``, plot its samples, then run and plot NeuTra HMC with it.

    :param args: parsed options; ``num_samples`` is read here and the object
        is forwarded to ``fit_guide`` and ``run_hmc``.
    :param guide: autoguide built for the module-level ``model``.
    :param str label: short guide name used in log messages and plot titles.
    :param axes: ``(ax_guide, ax_warped, ax_transformed)`` axes triple.
    :param grid: ``(x1, x2, p)`` triple for the background density.
    :param xlim: x-axis limits.
    :param ylim: y-axis limits.
    """
    ax_guide, ax_warped, ax_transformed = axes

    # (a) Fit the guide and plot samples drawn directly from it.
    logging.info('\nFitting a %s autoguide ...', label)
    fit_guide(guide, args)
    with pyro.plate('N', args.num_samples):
        guide_samples = guide()['x'].detach().cpu().numpy()
    _plot_density(ax_guide, grid, xlim, ylim,
                  'Posterior \n({} autoguide)'.format(label),
                  samples=guide_samples)

    # (b) Run HMC on the model reparameterized through the fitted guide.
    logging.info('\nDrawing samples using %s autoguide + NeuTra HMC ...', label)
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['x_shared_latent']

    # Plot from a host-side NumPy copy, consistent with the other sample sets
    # in this script; the tensor itself is still needed by `transform_sample`.
    zs_np = zs.detach().cpu().numpy()
    sns.scatterplot(x=zs_np[:, 0], y=zs_np[:, 1], alpha=0.2, ax=ax_warped)
    ax_warped.set(
        xlabel='x0',
        ylabel='x1',
        title='Posterior (warped) samples \n({} + NeuTra HMC)'.format(label))

    samples = neutra.transform_sample(zs)['x'].cpu().numpy()
    _plot_density(ax_transformed, grid, xlim, ylim,
                  'Posterior (transformed) \n({} + NeuTra HMC)'.format(label),
                  samples=samples)


def main(args):
    """Produce the eight-panel NeuTra comparison figure and save it as a PDF.

    Panels: the BananaShaped density, vanilla HMC samples, and for each of a
    DiagNormal and a BNAF autoguide the guide samples, the warped NeuTra HMC
    samples and the transformed NeuTra HMC samples. The figure is written to
    ``neutra.pdf`` in the directory of this file.

    :param args: parsed options. Attributes read here: ``rng_seed``,
        ``x_lim`` and ``y_lim`` (``'low,high'`` strings), ``param_a``,
        ``param_b``, ``num_samples`` and ``num_flows``; the object is also
        forwarded to ``run_hmc`` and ``fit_guide``.
    :raises ValueError: if ``x_lim`` or ``y_lim`` is not two comma-separated
        numbers.
    """
    pyro.set_rng_seed(args.rng_seed)
    fig = plt.figure(figsize=(8, 16), constrained_layout=True)
    gs = GridSpec(4, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[2, 0])
    ax5 = fig.add_subplot(gs[3, 0])
    ax6 = fig.add_subplot(gs[1, 1])
    ax7 = fig.add_subplot(gs[2, 1])
    ax8 = fig.add_subplot(gs[3, 1])
    # Floats (not ints) so fractional plot bounds such as '-2.5,2.5' work.
    xlim = _parse_limits(args.x_lim)
    ylim = _parse_limits(args.y_lim)

    # 1. Plot the BananaShaped density on a 100 x 100 grid
    x1, x2 = torch.meshgrid(
        [torch.linspace(*xlim, 100),
         torch.linspace(*ylim, 100)])
    d = BananaShaped(args.param_a, args.param_b)
    p = torch.exp(d.log_prob(torch.stack([x1, x2], dim=-1)))
    grid = (x1, x2, p)
    _plot_density(ax1, grid, xlim, ylim,
                  'BananaShaped distribution: \nlog density')

    # 2. Run vanilla HMC
    logging.info('\nDrawing samples using vanilla HMC ...')
    mcmc = run_hmc(args, model)
    vanilla_samples = mcmc.get_samples()['x'].cpu().numpy()
    _plot_density(ax2, grid, xlim, ylim, 'Posterior \n(vanilla HMC)',
                  samples=vanilla_samples)

    # 3. Diagonal normal autoguide, then NeuTra HMC through it (left column)
    _run_guide_stage(args, AutoDiagonalNormal(model, init_scale=0.05),
                     'DiagNormal', (ax3, ax4, ax5), grid, xlim, ylim)

    # 4. BNAF autoguide, then NeuTra HMC through it (right column)
    bnaf_guide = AutoNormalizingFlow(
        model, partial(iterated, args.num_flows, block_autoregressive))
    _run_guide_stage(args, bnaf_guide, 'BNAF', (ax6, ax7, ax8),
                     grid, xlim, ylim)

    plt.savefig(os.path.join(os.path.dirname(__file__), 'neutra.pdf'))
示例#5
0
def test_init():
    """Run ``check_init_reparam`` on ``neals_funnel`` with a NeuTra reparameterizer."""
    flow_guide = AutoIAFNormal(neals_funnel)
    # The guide is called once before being wrapped (presumably so its lazy
    # setup has run -- confirm).
    flow_guide()

    reparam = NeuTraReparam(flow_guide)
    check_init_reparam(neals_funnel, reparam)