コード例 #1
0
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    """MCMC driven by a raw potential_fn should recover the mean/std of a
    Normal(true_mean, true_std) target; samples must be float64 in x64 mode."""
    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        # Negative log-density of Normal(true_mean, true_std), up to a constant.
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_params = np.array(0.)
    kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    hmc_states = mcmc.get_samples()
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)

    # BUG FIX: the env var is spelled JAX_ENABLE_X64 (uppercase X64, as in the
    # other x64 tests in this file); the previous lowercase 'JAX_ENABLE_x64'
    # never matched, so the dtype assertion was silently skipped.
    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == np.float64
コード例 #2
0
ファイル: test_mcmc.py プロジェクト: while519/numpyro
def test_reuse_mcmc_run(jit_args, shape):
    """Re-running an MCMC object on new data (possibly with a different
    observation shape) should work and, with jit_model_args, reuse the
    compiled model; the posterior must track the new data's mean."""
    y1 = np.random.normal(3, 0.1, (100,))
    y2 = np.random.normal(-3, 0.1, (shape,))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    # First run on the initial dataset (mean ~ +3).
    kernel = NUTS(model)
    mcmc = MCMC(kernel, 300, 500, jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(32), y1)

    # Re-run on new data (mean ~ -3) - should be much faster when jitted.
    mcmc.run(random.PRNGKey(32), y2)
    assert_allclose(mcmc.get_samples()['mu'].mean(), -3., atol=0.1)
コード例 #3
0
def test_predictive_with_improper():
    """Posterior predictive sampling for a model that contains a constrained
    numpyro.param site alongside sample sites."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.param('loc',
                            0.,
                            constraint=constraints.interval(0., alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    observed = true_coef + random.normal(random.PRNGKey(0), (1000, ))
    sampler = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    sampler.run(random.PRNGKey(0), observed)
    posterior = sampler.get_samples()
    predicted = Predictive(model, posterior)(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(np.mean(predicted), true_coef, atol=0.05)
コード例 #4
0
ファイル: test_mcmc.py プロジェクト: while519/numpyro
def test_improper_prior():
    """Inference with mask(False) and ImproperUniform priors should still
    recover the parameters of the data-generating Normal."""
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 1000, 8000

    def model(data):
        mean = numpyro.sample('mean', dist.Normal(0, 1).mask(False))
        std = numpyro.sample('std', dist.ImproperUniform(dist.constraints.positive, (), ()))
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    observations = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))
    sampler = MCMC(NUTS(model=model), num_warmup, num_samples)
    sampler.warmup(random.PRNGKey(2), observations)
    sampler.run(random.PRNGKey(2), observations)
    posterior = sampler.get_samples()
    assert_allclose(jnp.mean(posterior['mean']), true_mean, rtol=0.05)
    assert_allclose(jnp.mean(posterior['std']), true_std, rtol=0.05)
コード例 #5
0
ファイル: test_mcmc.py プロジェクト: while519/numpyro
def test_prior_with_sample_shape():
    """A sample_shape passed to numpyro.sample must be reflected in the
    shape of the collected posterior samples (eight-schools model)."""
    data = {
        "J": 8,
        "y": jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        "sigma": jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    }

    def schools_model():
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        tau = numpyro.sample('tau', dist.HalfCauchy(5))
        theta = numpyro.sample('theta', dist.Normal(mu, tau), sample_shape=(data['J'],))
        numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])

    n_draws = 500
    sampler = MCMC(NUTS(schools_model), num_warmup=500, num_samples=n_draws)
    sampler.run(random.PRNGKey(0))
    assert sampler.get_samples()['theta'].shape == (n_draws, data['J'])
コード例 #6
0
ファイル: test_mcmc.py プロジェクト: while519/numpyro
def test_improper_normal():
    """TransformReparam over a masked (improper) TransformedDistribution
    should still let the sampler recover the true location."""
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.TransformedDistribution(
                dist.Uniform(0, 1).mask(False),
                AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    observed = true_coef + random.normal(random.PRNGKey(0), (1000,))
    sampler = MCMC(NUTS(model=model), num_warmup=1000, num_samples=1000)
    sampler.run(random.PRNGKey(0), observed)
    draws = sampler.get_samples()
    assert_allclose(jnp.mean(draws['loc'], 0), true_coef, atol=0.05)
コード例 #7
0
ファイル: run_fits_single.py プロジェクト: dimarkov/pybefit
    def inference(belief_sequences, obs, mask, rng_key):
        """Run vectorized-chain NUTS and return (samples, mean potential energy).

        Free variables (model, num_warmup, num_samples, num_chains) come from
        the enclosing scope.
        """
        sampler = MCMC(
            NUTS(model, dense_mass=True),
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            chain_method="vectorized",
            progress_bar=False,
        )
        sampler.run(rng_key, belief_sequences, obs, mask,
                    extra_fields=('potential_energy',))
        samples = sampler.get_samples()
        potential_energy = sampler.get_extra_fields()['potential_energy'].mean()
        return samples, potential_energy
コード例 #8
0
def run_inference(model, at_bats, hits, rng_key, args):
    """Run MCMC for the baseball model and return posterior samples.

    ``args.algo`` selects the kernel: "NUTS", "HMC", or "SA".

    Raises:
        ValueError: for an unrecognized ``args.algo`` (previously this fell
            through the if/elif chain and crashed later with a NameError on
            ``kernel``).
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    elif args.algo == "SA":
        kernel = SA(model)
    else:
        raise ValueError("Unknown algorithm: {!r}".format(args.algo))
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        # Hide the progress bar during Sphinx doc builds or on request.
        progress_bar=False
        if ("NUMPYRO_SPHINXBUILD" in os.environ or args.disable_progbar)
        else True,
    )
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
コード例 #9
0
 def get_samples(rng_key, data, step_size, trajectory_length,
                 target_accept_prob):
     # Build a kernel from the enclosing scope's kernel_cls/model with the
     # given HMC tuning parameters, run two chains, and return the samples.
     # Free variables (model, num_warmup, num_samples, chain_method) are
     # presumably supplied by an enclosing parametrized test -- confirm there.
     kernel = kernel_cls(
         model,
         step_size=step_size,
         trajectory_length=trajectory_length,
         target_accept_prob=target_accept_prob,
     )
     mcmc = MCMC(
         kernel,
         num_warmup=num_warmup,
         num_samples=num_samples,
         num_chains=2,
         chain_method=chain_method,
         progress_bar=False,
     )
     mcmc.run(rng_key, data)
     return mcmc.get_samples()
コード例 #10
0
ファイル: SPIRE.py プロジェクト: MCarmenCampos/XID_plus
def all_bands(priors,
              num_samples=500,
              num_warmup=500,
              num_chains=4,
              chain_method='parallel'):
    """Fit the SPIRE model with NUTS, collecting potential and total energy,
    and return the finished MCMC object."""
    numpyro.set_host_device_count(4)
    sampler = MCMC(NUTS(spire_model),
                   num_samples=num_samples,
                   num_warmup=num_warmup,
                   num_chains=num_chains,
                   chain_method=chain_method)
    sampler.run(random.PRNGKey(0), priors,
                extra_fields=('potential_energy', 'energy'))
    return sampler
コード例 #11
0
def test_initial_inverse_mass_matrix_ndarray(dense_mass):
    """A flat ndarray inverse_mass_matrix spanning all sites should be stored
    under the single combined key ("x", "z")."""
    def model():
        numpyro.sample("z", dist.Normal(0, 1).expand([2]))
        numpyro.sample("x", dist.Normal(0, 1).expand([3]))

    initial_mm = jnp.arange(1, 6.0)
    sampler = MCMC(
        NUTS(
            model,
            dense_mass=dense_mass,
            inverse_mass_matrix=initial_mm,
            adapt_mass_matrix=False,
        ),
        num_warmup=1,
        num_samples=1,
    )
    sampler.run(random.PRNGKey(0))
    mm = sampler.last_state.adapt_state.inverse_mass_matrix
    assert set(mm.keys()) == {("x", "z")}
    want = jnp.diag(initial_mm) if dense_mass else initial_mm
    assert_allclose(mm[("x", "z")], want)
コード例 #12
0
def sample_posterior_with_predictive(rng_key: random.PRNGKey,
                                     model,
                                     data: np.ndarray,
                                     Nsamples: int = 1000,
                                     alpha: float = 1,
                                     sigma: float = 0,
                                     T: int = 10):
    """Run NUTS on `model`, then return posterior-predictive draws of site "z".

    NUM_WARMUP is a module-level constant.
    """
    sampler = MCMC(NUTS(model), num_samples=Nsamples, num_warmup=NUM_WARMUP)
    sampler.run(rng_key, data=data, alpha=alpha, sigma=sigma, T=T)
    draws = sampler.get_samples()

    predictive = Predictive(model,
                            posterior_samples=draws,
                            return_sites=["z"])
    return predictive(rng_key, data=data, alpha=alpha, sigma=sigma, T=T)["z"]
コード例 #13
0
def test_discrete_gibbs_multiple_sites_chain(kernel, inner_kernel, kwargs, num_chains):
    """Discrete Gibbs over two discrete sites, possibly with several chains;
    posterior means must match the true Bernoulli/Binomial parameters."""
    def model():
        numpyro.sample("x", dist.Bernoulli(0.7).expand([3]))
        numpyro.sample("y", dist.Binomial(10, 0.3))

    sampler = MCMC(
        kernel(inner_kernel(model), **kwargs),
        num_warmup=1000,
        num_samples=10000,
        num_chains=num_chains,
        progress_bar=False,
    )
    sampler.run(random.PRNGKey(0))
    sampler.print_summary()
    draws = sampler.get_samples()
    assert_allclose(jnp.mean(draws["x"], 0), 0.7 * jnp.ones(3), atol=0.01)
    assert_allclose(jnp.mean(draws["y"], 0), 0.3 * 10, atol=0.1)
コード例 #14
0
def test_diverging(kernel_cls, adapt_step_size):
    """With a huge fixed step size every transition should diverge; only
    step-size adaptation (when enabled) can bring divergences down."""
    data = random.normal(random.PRNGKey(0), (1000,))

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    # step_size=10 is far too large; mass-matrix adaptation is disabled so
    # step-size adaptation is the only mechanism that can rescue the sampler.
    kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size, adapt_mass_matrix=False)
    num_warmup = num_samples = 1000
    mcmc = MCMC(kernel, num_warmup, num_samples)
    # collect_warmup=True so warmup divergences can be counted separately.
    mcmc.warmup(random.PRNGKey(1), data, extra_fields=['diverging'], collect_warmup=True)
    warmup_divergences = mcmc.get_extra_fields()['diverging'].sum()
    mcmc.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    num_divergences = warmup_divergences + mcmc.get_extra_fields()['diverging'].sum()
    if adapt_step_size:
        # Adaptation should keep total divergences below the warmup count.
        assert num_divergences <= num_warmup
    else:
        # Without adaptation every single transition diverges.
        assert_allclose(num_divergences, num_warmup + num_samples)
コード例 #15
0
ファイル: hmm.py プロジェクト: freddyaboulton/numpyro
def main(args):
    """Simulate semi-supervised HMM data, fit with NUTS, print results, and
    save a KDE plot of the transition-probability posterior."""
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words,
     unsupervised_words) = simulate_data(
         random.PRNGKey(1),
         num_categories=args.num_categories,
         num_words=args.num_words,
         num_supervised_data=args.num_supervised,
         num_unsupervised_data=args.num_unsupervised,
     )
    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(
        kernel,
        args.num_warmup,
        args.num_samples,
        num_chains=args.num_chains,
        # Hide the progress bar when building docs with Sphinx.
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words, args.unroll_loop)
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)

    # make plots: one KDE curve per transition-probability entry
    fig, ax = plt.subplots(1, 1)

    x = np.linspace(0, 1, 101)
    for i in range(transition_prob.shape[0]):
        for j in range(transition_prob.shape[1]):
            ax.plot(x,
                    gaussian_kde(samples['transition_prob'][:, i, j])(x),
                    label="trans_prob[{}, {}], true value = {:.2f}".format(
                        i, j, transition_prob[i, j]))
    ax.set(xlabel="Probability",
           ylabel="Frequency",
           title="Transition probability posterior")
    ax.legend()

    # BUG FIX: tight_layout() must run before savefig(); previously it was
    # called after saving, so the written PDF never got the adjusted layout.
    plt.tight_layout()
    plt.savefig("hmm_plot.pdf")
コード例 #16
0
ファイル: test_control_flow.py プロジェクト: dirmeier/numpyro
def test_scan():
    """Inference through a `scan`-based time-series model: checks the set of
    collected sites, then posterior prediction over a longer horizon via
    `condition` with parallel (vmap) Predictive."""
    def model(T=10, q=1, r=1, phi=0.0, beta=0.0):
        # One step of the latent AR process; returns carry and per-step outputs.
        def transition(state, i):
            x0, mu0 = state
            x1 = numpyro.sample("x", dist.Normal(phi * x0, q))
            mu1 = beta * mu0 + x1
            y1 = numpyro.sample("y", dist.Normal(mu1, r))
            numpyro.deterministic("y2", y1 * 2)
            return (x1, mu1), (x1, y1)

        mu0 = x0 = numpyro.sample("x_0", dist.Normal(0, q))
        y0 = numpyro.sample("y_0", dist.Normal(mu0, r))

        _, xy = scan(transition, (x0, mu0), jnp.arange(T))
        x, y = xy

        # Full trajectories including the initial step.
        return jnp.append(x0, x), jnp.append(y0, y)

    T = 10
    num_samples = 100
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), T=T)
    assert set(mcmc.get_samples()) == {"x", "y", "y2", "x_0", "y_0"}
    mcmc.print_summary()

    samples = mcmc.get_samples()
    x = samples.pop("x")[0]  # take 1 sample of x
    # this tests for the composition of condition and substitute
    # this also tests if we can use `vmap` for predictive.
    future = 5
    predictive = Predictive(
        numpyro.handlers.condition(model, {"x": x}),
        samples,
        return_sites=["x", "y", "y2"],
        parallel=True,
    )
    result = predictive(random.PRNGKey(1), T=T + future)
    expected_shape = (num_samples, T + future)
    assert result["x"].shape == expected_shape
    assert result["y"].shape == expected_shape
    assert result["y2"].shape == expected_shape
    # The conditioned "x" values must be reproduced over the observed window.
    assert_allclose(result["x"][:, :T], jnp.broadcast_to(x, (num_samples, T)))
    assert_allclose(result["y"][:, :T], samples["y"])
コード例 #17
0
def fit_numpyro(progress_bar=False,
                model=None,
                num_warmup=1000,
                n_draws=200,
                num_chains=4,
                sampler=NUTS,
                use_gpu=False,
                **kwargs):
    """Fit `model` (default: models.model_hierarchical) with MCMC.

    Extra **kwargs are forwarded to mcmc.run() as model arguments.
    Returns a tuple of (arviz trace, MCMC object).
    """
    if 'bayes_window_test_mode' in os.environ:
        # Override settings with minimal values for fast test runs.
        use_gpu = False
        num_warmup = 5
        n_draws = 5
        num_chains = 1
    select_device(use_gpu, num_chains)
    model = model or models.model_hierarchical
    mcmc = MCMC(
        sampler(
            model=model,
            find_heuristic_step_size=True,
            target_accept_prob=0.99,
            # init_strategy=numpyro.infer.init_to_uniform
        ),
        num_warmup=num_warmup,
        num_samples=n_draws,
        num_chains=num_chains,
        progress_bar=progress_bar,
        chain_method='parallel')
    mcmc.run(jax.random.PRNGKey(16), **kwargs)

    # arviz convert; fall back to a plain dict conversion when from_numpyro
    # is incompatible with this numpyro version (raises AttributeError).
    try:
        trace = az.from_numpyro(mcmc)
    except AttributeError:
        trace = az.from_dict(mcmc.get_samples())
        print(trace.posterior)

    # Print diagnostics (divergence count, when sample stats are available).
    if 'sample_stats' in trace:
        if trace.sample_stats.diverging.sum(['chain', 'draw']).values > 0:
            print(
                f"n(Divergences) = {trace.sample_stats.diverging.sum(['chain', 'draw']).values}"
            )
    return trace, mcmc
コード例 #18
0
ファイル: base.py プロジェクト: gcgibson/covid
    def infer(self, num_warmup=1000, num_samples=1000, num_chains=1, rng_key=PRNGKey(1), **args):
        '''Fit using MCMC'''

        # Merge per-call arguments over the model's stored defaults.
        args = dict(self.args, **args)

        kernel = NUTS(self, init_strategy=numpyro.infer.util.init_to_median())
        mcmc = MCMC(kernel,
                    num_warmup=num_warmup,
                    num_samples=num_samples,
                    num_chains=num_chains)

        mcmc.run(rng_key, **self.obs, **args)
        mcmc.print_summary()

        # Cache both the sampler and its samples on the instance.
        self.mcmc = mcmc
        self.mcmc_samples = mcmc.get_samples()

        return self.mcmc_samples
コード例 #19
0
def run_inference(design_matrix: jnp.ndarray, outcome: jnp.ndarray,
                  rng_key: jnp.ndarray,
                  num_warmup: int,
                  num_samples: int, num_chains: int,
                  interval_size: float = 0.95) -> None:
    """
    Estimate the effect size.
    """
    sampler = MCMC(NUTS(model), num_warmup, num_samples, num_chains,
                   progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    sampler.run(rng_key, design_matrix, outcome)

    # Column 0: intercept (not getting called).
    # Column 1: effect of getting called.
    # Column 2: effect of gender (should be none since assigned at random).
    coef = sampler.get_samples()['coefficients']
    print_results(coef, interval_size)
コード例 #20
0
def test_initial_inverse_mass_matrix(dense_mass):
    """A dict-valued inverse_mass_matrix covering only site "x" should leave
    site "z" with an identity (ones) inverse mass matrix."""
    def model():
        numpyro.sample("x", dist.Normal(0, 1).expand([3]))
        numpyro.sample("z", dist.Normal(0, 1).expand([2]))

    initial_mm = jnp.arange(1, 4.0)
    sampler = MCMC(
        NUTS(
            model,
            dense_mass=dense_mass,
            inverse_mass_matrix={("x",): initial_mm},
            adapt_mass_matrix=False,
        ),
        1,
        1,
    )
    sampler.run(random.PRNGKey(0))
    mm = sampler.last_state.adapt_state.inverse_mass_matrix
    assert set(mm.keys()) == {("x",), ("z",)}
    want = jnp.diag(initial_mm) if dense_mass else initial_mm
    assert_allclose(mm[("x",)], want)
    assert_allclose(mm[("z",)], jnp.ones(2))
    def _run_inference(self, rng_key=None, X=None, Y=None):
        ''' Run inference on the model specified above with the supplied data.

        Returns the posterior samples from NUTS/MCMC fit on (X, Y).
        '''

        if rng_key is None:
            rng_key = random.PRNGKey(self.random_state)

        # Split the key per chain when running multiple chains.
        if self.num_chains > 1:
            rng_key_ = random.split(rng_key, self.num_chains)
        else:
            rng_key, rng_key_ = random.split(rng_key)

        # The following samples parameter settings with NUTS and MCMC to fit the posterior based on the provided data (X,Y)
        start = time.time()
        kernel = NUTS(self._model)
        mcmc = MCMC(kernel, self.num_warmup, self.num_samples)
        mcmc.run(rng_key_, X=X, Y=Y)
        # BUG FIX: '/n' was a typo for the newline escape '\n'.
        print('\n MCMC elapsed time:', time.time() - start)

        return mcmc.get_samples()
コード例 #22
0
def test_chain_smoke(chain_method, compile_args):
    """Smoke test: warmup + run with two chains under each chain method and
    jit_model_args setting."""
    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    observed = dist.Categorical(np.array([0.1, 0.6, 0.3])).sample(
        random.PRNGKey(1), (2000, ))
    sampler = MCMC(NUTS(model),
                   2,
                   5,
                   num_chains=2,
                   chain_method=chain_method,
                   jit_model_args=compile_args)
    sampler.warmup(random.PRNGKey(0), observed)
    sampler.run(random.PRNGKey(1), observed)
コード例 #23
0
def mcmc_inference(model, num_warmup, num_samples, num_chains, rng_key, X, Y):
    """
    Helper function for doing NUTS inference.

    :param model: a parametric function proportional to the posterior (see gp_regression.likelihood).
    :param num_warmup: warmup steps.
    :param num_samples: number of samples.
    :param num_chains: number of Markov chains used for MCMC sampling.
    :param rng_key: random seed.
    :param X: X data.
    :param Y: Y data.
    :return: Dictionary key: name of parameter (from defined in model), value: list of samples.
    """
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=num_chains)
    mcmc.run(rng_key, X, Y)
    print('\nMCMC time:', time.time() - start)
    # BUG FIX: print_summary() prints directly and returns None, so the old
    # print(mcmc.print_summary()) emitted a spurious "None" line. Also fixed
    # the stray fourth quote that opened the docstring.
    mcmc.print_summary()
    return mcmc.get_samples()
コード例 #24
0
def test_uniform_normal():
    """Collecting warmup samples and then sampling should yield num_warmup
    and num_samples draws respectively, with the correct posterior mean."""
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    observed = true_coef + random.normal(random.PRNGKey(0), (1000, ))
    sampler = MCMC(NUTS(model=model), num_warmup=num_warmup, num_samples=num_samples)
    sampler.warmup(random.PRNGKey(2), observed, collect_warmup=True)
    warmup_draws = sampler.get_samples()
    sampler.run(random.PRNGKey(3), observed)
    draws = sampler.get_samples()
    assert len(warmup_draws['loc']) == num_warmup
    assert len(draws['loc']) == num_samples
    assert_allclose(np.mean(draws['loc'], 0), true_coef, atol=0.05)
コード例 #25
0
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """Dirichlet-categorical posterior mean should match the true category
    probabilities; samples are float64 when JAX x64 mode is enabled."""
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    observed = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000, ))
    sampler = MCMC(kernel_cls(model, trajectory_length=1., dense_mass=dense_mass),
                   warmup_steps, num_samples, progress_bar=False)
    sampler.run(random.PRNGKey(2), observed)
    draws = sampler.get_samples()
    assert_allclose(np.mean(draws['p_latent'], 0), true_probs, atol=0.02)

    if 'JAX_ENABLE_X64' in os.environ:
        assert draws['p_latent'].dtype == np.float64
コード例 #26
0
def test_bernoulli_latent_model():
    """Inference over a model with latent Bernoulli sites inside a plate;
    the posterior mean of y_prob must match the generating probability."""
    def model(data):
        y_prob = numpyro.sample("y_prob", dist.Beta(1., 1.))
        with numpyro.plate("data", data.shape[0]):
            y = numpyro.sample("y", dist.Bernoulli(y_prob))
            z = numpyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
            numpyro.sample("obs", dist.Normal(2. * z, 1.), obs=data)

    n_points = 2000
    true_y_prob = 0.3
    y = dist.Bernoulli(true_y_prob).sample(random.PRNGKey(0), (n_points,))
    z = dist.Bernoulli(0.65 * y + 0.1).sample(random.PRNGKey(1))
    observed = dist.Normal(2. * z, 1.0).sample(random.PRNGKey(2))

    sampler = MCMC(NUTS(model), num_warmup=500, num_samples=500)
    sampler.run(random.PRNGKey(3), observed)
    draws = sampler.get_samples()
    assert_allclose(draws["y_prob"].mean(0), true_y_prob, atol=0.05)
コード例 #27
0
def test_predictive(parallel):
    """Predictive honors the default and overridden return_sites, producing
    the expected sample shapes and means."""
    model, data, true_probs = beta_bernoulli()
    sampler = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    sampler.run(random.PRNGKey(0), data)
    draws = sampler.get_samples()
    predictive = Predictive(model, draws, parallel=parallel)
    pred = predictive(random.PRNGKey(1))
    assert pred.keys() == {"beta_sq", "obs"}

    predictive.return_sites = ["beta", "beta_sq", "obs"]
    pred = predictive(random.PRNGKey(1))
    # check shapes
    assert pred["beta"].shape == (100, ) + true_probs.shape
    assert pred["beta_sq"].shape == (100, ) + true_probs.shape
    assert pred["obs"].shape == (100, ) + data.shape
    # check sample mean
    obs = pred["obs"].reshape((-1, ) + true_probs.shape).astype(np.float32)
    assert_allclose(obs.mean(0), true_probs, rtol=0.1)
コード例 #28
0
def test_logistic_regression_x64(kernel_cls):
    """Logistic regression under several kernels (SA, BarkerMH, HMC/NUTS);
    posterior coefficient means must match reference MAP values, and samples
    must be float64 when x64 mode is enabled."""
    N, dim = 3000, 3
    # SA mixes slowly and needs far more iterations; BarkerMH is intermediate.
    if kernel_cls is SA:
        num_warmup, num_samples = (100000, 100000)
    elif kernel_cls is BarkerMH:
        num_warmup, num_samples = (2000, 12000)
    else:
        num_warmup, num_samples = (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs",
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        # Exposed as a deterministic site so its shape can be asserted below.
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data,
                                                         axis=-1))
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    # Kernel-specific construction options.
    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(model=model,
                            trajectory_length=8,
                            find_heuristic_step_size=True)
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    # those coefficients are found by doing MAP inference using AutoDelta
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.1)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["coefs"].dtype == jnp.float64
コード例 #29
0
ファイル: test_tfp.py プロジェクト: gully/numpyro
def test_beta_bernoulli():
    """Beta-Bernoulli inference using TFP-backed distributions from
    numpyro.contrib.tfp; posterior means must match the true probabilities."""
    # NOTE: this local import shadows any module-level `dist` with the TFP
    # wrapper module for the duration of this test.
    from numpyro.contrib.tfp import distributions as dist

    warmup_steps, num_samples = (500, 2000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    # TFP distributions are called (not .sample'd) to draw samples.
    data = dist.Bernoulli(true_probs)(rng_key=random.PRNGKey(1), sample_shape=(1000, 2))
    kernel = NUTS(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.05)
コード例 #30
0
ファイル: hmm.py プロジェクト: jwschroeder3/numpyro
def main(args):
    """Simulate semi-supervised HMM data, fit it with NUTS, and report the
    posterior against the true transition/emission probabilities."""
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )
    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    sampler = MCMC(NUTS(semi_supervised_hmm), args.num_warmup, args.num_samples)
    sampler.run(rng_key, transition_prior, emission_prior,
                supervised_categories, supervised_words, unsupervised_words)
    samples = sampler.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)