Code example #1
def test_predictive_with_improper():
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample(
                'loc',
                dist.TransformedDistribution(
                    dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    obs_pred = Predictive(model, samples)(random.PRNGKey(1), data=None)["obs"]
    assert_allclose(jnp.mean(obs_pred), true_coef, atol=0.05)
Code example #2
def main(args):
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words,
     unsupervised_words) = simulate_data(
         random.PRNGKey(1),
         num_categories=args.num_categories,
         num_words=args.num_words,
         num_supervised_data=args.num_supervised,
         num_unsupervised_data=args.num_unsupervised,
     )
    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words, args.unroll_loop)
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    x = np.linspace(0, 1, 101)
    for i in range(transition_prob.shape[0]):
        for j in range(transition_prob.shape[1]):
            ax.plot(x,
                    gaussian_kde(samples['transition_prob'][:, i, j])(x),
                    label="trans_prob[{}, {}], true value = {:.2f}".format(
                        i, j, transition_prob[i, j]))
    ax.set(xlabel="Probability",
           ylabel="Frequency",
           title="Transition probability posterior")
    ax.legend()

    plt.savefig("hmm_plot.pdf")
Code example #3
def run_inference(design_matrix: jnp.ndarray, outcome: jnp.ndarray,
                  rng_key: jnp.ndarray,
                  num_warmup: int,
                  num_samples: int, num_chains: int,
                  interval_size: float = 0.95) -> None:
    """
    Estimate the effect size.
    """

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                num_chains=num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, design_matrix, outcome)

    # 0th column is intercept (not getting called)
    # 1st column is effect of getting called
    # 2nd column is effect of gender (should be none since assigned at random)
    coef = mcmc.get_samples()['coefficients']
    print_results(coef, interval_size)
Code example #4
File: test_mcmc.py  Project: xidulu/numpyro
def test_prior_with_sample_shape():
    data = {
        "J": 8,
        "y": jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
        "sigma": jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
    }

    def schools_model():
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        tau = numpyro.sample('tau', dist.HalfCauchy(5))
        theta = numpyro.sample('theta',
                               dist.Normal(mu, tau),
                               sample_shape=(data['J'], ))
        numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])

    num_samples = 500
    mcmc = MCMC(NUTS(schools_model), num_warmup=500, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0))
    assert mcmc.get_samples()['theta'].shape == (num_samples, data['J'])
Code example #5
File: test_mcmc.py  Project: fehiepsi/numpyro
def test_improper_normal(max_tree_depth):
    true_coef = 0.9

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1).mask(False), AffineTransform(0, alpha)),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))
    kernel = NUTS(model=model, max_tree_depth=max_tree_depth)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)
Code example #6
File: test_funsor.py  Project: while519/numpyro
def test_bernoulli_latent_model():
    def model(data):
        y_prob = numpyro.sample("y_prob", dist.Beta(1., 1.))
        with numpyro.plate("data", data.shape[0]):
            y = numpyro.sample("y", dist.Bernoulli(y_prob))
            z = numpyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
            numpyro.sample("obs", dist.Normal(2. * z, 1.), obs=data)

    N = 2000
    y_prob = 0.3
    y = dist.Bernoulli(y_prob).sample(random.PRNGKey(0), (N, ))
    z = dist.Bernoulli(0.65 * y + 0.1).sample(random.PRNGKey(1))
    data = dist.Normal(2. * z, 1.0).sample(random.PRNGKey(2))

    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert_allclose(samples["y_prob"].mean(0), y_prob, atol=0.05)
Code example #7
File: test_mcmc.py  Project: fehiepsi/numpyro
def test_initial_inverse_mass_matrix(dense_mass):
    def model():
        numpyro.sample("x", dist.Normal(0, 1).expand([3]))
        numpyro.sample("z", dist.Normal(0, 1).expand([2]))

    expected_mm = jnp.arange(1, 4.0)
    kernel = NUTS(
        model,
        dense_mass=dense_mass,
        inverse_mass_matrix={("x", ): expected_mm},
        adapt_mass_matrix=False,
    )
    mcmc = MCMC(kernel, num_warmup=1, num_samples=1)
    mcmc.run(random.PRNGKey(0))
    inverse_mass_matrix = mcmc.last_state.adapt_state.inverse_mass_matrix
    assert set(inverse_mass_matrix.keys()) == {("x", ), ("z", )}
    expected_mm = jnp.diag(expected_mm) if dense_mass else expected_mm
    assert_allclose(inverse_mass_matrix[("x", )], expected_mm)
    assert_allclose(inverse_mass_matrix[("z", )], jnp.ones(2))
Code example #8
File: test_infer_util.py  Project: hessammehr/numpyro
def test_predictive(parallel):
    model, data, true_probs = beta_bernoulli()
    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(random.PRNGKey(0), data)
    samples = mcmc.get_samples()
    predictive = Predictive(model, samples, parallel=parallel)
    predictive_samples = predictive(random.PRNGKey(1))
    assert predictive_samples.keys() == {"beta_sq", "obs"}

    predictive.return_sites = ["beta", "beta_sq", "obs"]
    predictive_samples = predictive(random.PRNGKey(1))
    # check shapes
    assert predictive_samples["beta"].shape == (100, ) + true_probs.shape
    assert predictive_samples["beta_sq"].shape == (100, ) + true_probs.shape
    assert predictive_samples["obs"].shape == (100, ) + data.shape
    # check sample mean
    obs = predictive_samples["obs"].reshape((-1, ) + true_probs.shape).astype(
        np.float32)
    assert_allclose(obs.mean(0), true_probs, rtol=0.1)
Code example #9
File: test_mcmc.py  Project: xidulu/numpyro
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000, ))
    kernel = kernel_cls(model, trajectory_length=1., dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.02)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == jnp.float64
Code example #10
def test_random_module_mcmc(backend, init):

    if backend == "flax":
        import flax

        linear_module = flax.linen.Dense(features=1)
        bias_name = "bias"
        weight_name = "kernel"
        random_module = random_flax_module
        kwargs_name = "inputs"
    elif backend == "haiku":
        import haiku as hk

        linear_module = hk.transform(lambda x: hk.Linear(1)(x))
        bias_name = "linear.b"
        weight_name = "linear.w"
        random_module = random_haiku_module
        kwargs_name = "x"

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1.0, dim + 1.0)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    if init == "shape":
        kwargs = {"input_shape": (3,)}
    elif init == "kwargs":
        kwargs = {kwargs_name: data}

    def model(data, labels):
        nn = random_module(
            "nn",
            linear_module,
            {bias_name: dist.Cauchy(), weight_name: dist.Normal()},
            **kwargs
        )
        logits = nn(data).squeeze(-1)
        numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(2), data, labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert set(samples.keys()) == {
        "nn/{}".format(bias_name),
        "nn/{}".format(weight_name),
    }
    assert_allclose(
        np.mean(samples["nn/{}".format(weight_name)].squeeze(-1), 0),
        true_coefs,
        atol=0.22,
    )
Code example #11
File: gp.py  Project: xidulu/numpyro
def run_inference(model, args, rng_key, X, Y):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
Code example #12
    def init_mcmc(
        self,
        model,
        num_warmup: int = 1000,
        num_samples: int = 1000,
        num_chains: int = 1,
        sampler="NUTS",
        sampler_kwargs=None,
        **kwargs,
    ) -> MCMC:
        """Initialises the MCMC sampler.

        Args:
            model (callable): The NumPyro model to run inference on.
            num_warmup (int): Number of warmup iterations. Defaults to 1000.
            num_samples (int): Number of posterior samples to draw. Defaults
                to 1000.
            num_chains (int): Number of MCMC chains. Defaults to 1.
            sampler (str, or numpyro.infer.mcmc.MCMCKernel): Choose one of
                ['NUTS'], or pass a numpyro mcmc kernel.
            sampler_kwargs (dict, optional): Keyword arguments to pass to the
                chosen sampler.
            **kwargs: Keyword arguments to pass to the MCMC instance.

        """
        self._mcmc_support_warnings()
        # Avoid the mutable-default pitfall and never mutate the caller's dict.
        sampler_kwargs = dict(sampler_kwargs or {})

        if isinstance(sampler, str):
            sampler = sampler.lower()
            if sampler != "nuts":
                raise ValueError(f"Sampler '{sampler}' not supported.")
            target_accept_prob = sampler_kwargs.pop("target_accept_prob", 0.98)
            init_strategy = sampler_kwargs.pop(
                "init_strategy",
                lambda site=None: init_to_median(site=site, num_samples=100),
            )
            step_size = sampler_kwargs.pop("step_size", 0.1)
            sampler = NUTS(
                model,
                target_accept_prob=target_accept_prob,
                init_strategy=init_strategy,
                step_size=step_size,
                **sampler_kwargs,
            )

        # if num_chains > 1:
        # self.batch_ndims = 2  # I.e. two dims for chains then samples

        return MCMC(
            sampler,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            **kwargs,
        )
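A minimal usage sketch for the method above; `fitter` (an instance of the class that defines `init_mcmc`) and `my_model` are assumed names, not part of the original source:

# Hypothetical usage: build the sampler, then drive it like any numpyro MCMC object.
mcmc = fitter.init_mcmc(my_model, num_warmup=500, num_samples=2000, num_chains=2)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples()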
Code example #13
File: my_numpyro.py  Project: yongduek/DBDA-python
class Sampler():
    def __init__(self, model, data=None):
        self.data = data
        self.num_warmup = 1000
        self.num_samples = 2000
        self.num_chains = 4
        self.mcmc = MCMC(NUTS(model),
                         num_warmup=self.num_warmup,
                         num_samples=self.num_samples,
                         num_chains=self.num_chains)

    def fit(self, data):
        self.data = data
        self.mcmc.run(random.PRNGKey(0), **data)
        self.post = self.mcmc.get_samples()
        return self.post  # posterior samples

    def predict(self, data):
        pass
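Since `fit` forwards `**data` to `mcmc.run`, the dict keys must match the model's argument names. A short sketch with an assumed toy model (not part of the original file):

import jax.numpy as jnp

def coin_model(flips=None):
    # Beta-Bernoulli model; `flips` arrives via the **data expansion in fit().
    p = numpyro.sample('p', dist.Beta(1., 1.))
    numpyro.sample('obs', dist.Bernoulli(p), obs=flips)

sampler = Sampler(coin_model)
posterior = sampler.fit({'flips': jnp.array([1., 0., 1., 1., 0.])})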
Code example #14
File: test_hmc_gibbs.py  Project: hessammehr/numpyro
def test_discrete_gibbs_gmm_1d(modified, kernel, inner_kernel, kwargs):
    def model(probs, locs):
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2, 0, 2, 4])
    sampler = kernel(inner_kernel(model, trajectory_length=1.2),
                     modified=modified,
                     **kwargs)
    mcmc = MCMC(sampler,
                num_warmup=1000,
                num_samples=200000,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"]), 1.3, atol=0.1)
    assert_allclose(jnp.var(samples["x"]), 4.36, atol=0.4)
    assert_allclose(jnp.mean(samples["c"]), 1.65, atol=0.1)
    assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1)
Code example #15
def test_beta_bernoulli_x64(kernel_cls):
    warmup_steps, num_samples = 500, 20000

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    kernel = kernel_cls(model=model, trajectory_length=1.)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
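This test only asserts the float64 dtype when 64-bit mode is active. As a side note (not part of the test file), that mode can be switched on either via the environment variable the test checks or programmatically before any sampling:

# Two equivalent ways to run NumPyro/JAX in double precision:
#   JAX_ENABLE_X64=1 python script.py    (environment variable)
import numpyro
numpyro.enable_x64()  # programmatic switch; call before building distributions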
Code example #16
File: hmm.py  Project: jwschroeder3/numpyro
def main(args):
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )
    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words)
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)
Code example #17
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    init_params = jnp.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
Code example #18
def test_binomial_stable_x64(with_logits):
    # Ref: https://github.com/pyro-ppl/pyro/issues/1706
    warmup_steps, num_samples = 200, 200

    def model(data):
        p = numpyro.sample('p', dist.Beta(1., 1.))
        if with_logits:
            logits = logit(p)
            numpyro.sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])
        else:
            numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])

    data = {'n': 5000000, 'x': 3849}
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples['p'], 0), data['x'] / data['n'], rtol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p'].dtype == jnp.float64
Code example #19
File: capture_recapture.py  Project: while519/numpyro
def run_inference(model, capture_history, sex, rng_key, args):
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, capture_history, sex)
    mcmc.print_summary()
    return mcmc.get_samples()
Code example #20
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = kernel_cls(model=model, trajectory_length=8)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == np.float64
Code example #21
def sample(model,
           num_samples,
           num_warmup,
           num_chains,
           seed=0,
           chain_method="parallel",
           summary=True,
           **kwargs):
    """Run the No-U-Turn sampler
    """
    rng_key = random.PRNGKey(seed)
    kernel = NUTS(model)
    # Note: sampling more than one chain doesn't show a progress bar
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
                chain_method=chain_method)
    mcmc.run(rng_key, **kwargs)

    if summary:
        mcmc.print_summary()

    # Return a fitted MCMC object
    return mcmc
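Because `sample` forwards its remaining `**kwargs` to `mcmc.run`, the observed data is passed by keyword. A usage sketch with an assumed toy model (not part of the original helper):

import jax.numpy as jnp

def model(y=None):
    # Simple location model; `y` arrives through the **kwargs expansion.
    mu = numpyro.sample("mu", dist.Normal(0., 10.))
    numpyro.sample("y", dist.Normal(mu, 1.), obs=y)

fitted = sample(model, num_samples=1000, num_warmup=500, num_chains=1,
                y=jnp.array([0.2, -0.5, 1.1]))
posterior = fitted.get_samples()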
Code example #22
    def test_inference_data_constant_data(self):
        import numpyro
        import numpyro.distributions as dist
        from numpyro.infer import MCMC, NUTS

        x1 = 10
        x2 = 12
        y1 = np.random.randn(10)

        def model_constant_data(x, y1=None):
            _x = numpyro.sample("x", dist.Normal(1, 3))
            numpyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)

        nuts_kernel = NUTS(model_constant_data)
        mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
        mcmc.run(PRNGKey(0), x=x1, y1=y1)
        posterior = mcmc.get_samples()
        posterior_predictive = Predictive(model_constant_data,
                                          posterior)(PRNGKey(1), x1)
        predictions = Predictive(model_constant_data, posterior)(PRNGKey(2),
                                                                 x2)
        inference_data = from_numpyro(
            mcmc,
            posterior_predictive=posterior_predictive,
            predictions=predictions,
            constant_data={"x1": x1},
            predictions_constant_data={"x2": x2},
        )
        test_dict = {
            "posterior": ["x"],
            "posterior_predictive": ["y1"],
            "sample_stats": ["diverging"],
            "log_likelihood": ["y1"],
            "predictions": ["y1"],
            "observed_data": ["y1"],
            "constant_data": ["x1"],
            "predictions_constant_data": ["x2"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails
Code example #23
def test_missing_plate(monkeypatch):
    K, N = 3, 1000

    def gmm(data):
        mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
        # plate/to_event is missing here
        cluster_means = numpyro.sample("cluster_means",
                                       dist.Normal(jnp.arange(K), 1.0))

        with numpyro.plate("data", data.shape[0], dim=-1):
            assignments = numpyro.sample("assignments",
                                         dist.Categorical(mix_proportions))
            numpyro.sample("obs",
                           dist.Normal(cluster_means[assignments], 1.0),
                           obs=data)

    true_cluster_means = jnp.array([1.0, 5.0, 10.0])
    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
        random.PRNGKey(0), (N, ))
    data = dist.Normal(true_cluster_means[cluster_assignments],
                       1.0).sample(random.PRNGKey(1))

    nuts_kernel = NUTS(gmm)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
    with pytest.raises(AssertionError, match="Missing plate statement"):
        mcmc.run(random.PRNGKey(2), data)

    monkeypatch.setattr(numpyro.infer.util, "_validate_model",
                        lambda model_trace: None)
    with pytest.raises(Exception):
        mcmc.run(random.PRNGKey(2), data)
    assert len(_PYRO_STACK) == 0
Code example #24
    def _sample(current_state, seed):
        step_size = jax.tree_map(jax.numpy.ones_like, init_state)
        nuts_kernel = NUTS(
            potential_fn=lambda x: -logp_fn_jax(*x),
            # model=model,
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )

        pmap_numpyro = MCMC(
            nuts_kernel,
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            chain_method=chain_method,
            progress_bar=progress_bar,
        )

        pmap_numpyro.run(seed,
                         init_params=current_state,
                         extra_fields=("num_steps", ))
        samples = pmap_numpyro.get_samples(group_by_chain=True)
        leapfrogs_taken = pmap_numpyro.get_extra_fields(
            group_by_chain=True)["num_steps"]
        return samples, leapfrogs_taken
Code example #25
File: test_hmc_gibbs.py  Project: mjbajwa/numpyro
def test_linear_model_sigma(kernel_cls,
                            N=90,
                            P=40,
                            sigma=0.07,
                            warmup_steps=500,
                            num_samples=500):
    np.random.seed(1)
    X = np.random.randn(N * P).reshape((N, P))
    XX = np.matmul(np.transpose(X), X)
    Y = X[:, 0] + sigma * np.random.randn(N)
    XY = np.sum(X * Y[:, None], axis=0)

    def model(X, Y):
        N, P = X.shape

        sigma = numpyro.sample("sigma", dist.HalfCauchy(1.0))
        beta = numpyro.sample("beta", dist.Normal(jnp.zeros(P), jnp.ones(P)))
        mean = jnp.sum(beta * X, axis=-1)

        numpyro.sample("obs", dist.Normal(mean, sigma), obs=Y)

    gibbs_fn = partial(_linear_regression_gibbs_fn, X, XX, XY, Y)

    hmc_kernel = kernel_cls(model)
    kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['beta'])
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples,
                progress_bar=False)

    mcmc.run(random.PRNGKey(0), X, Y)

    beta_mean = np.mean(mcmc.get_samples()['beta'], axis=0)
    assert_allclose(beta_mean, np.array([1.0] + [0.0] * (P - 1)), atol=0.05)

    sigma_mean = np.mean(mcmc.get_samples()['sigma'], axis=0)
    assert_allclose(sigma_mean, sigma, atol=0.25)
Code example #26
File: test_hmc_gibbs.py  Project: mjbajwa/numpyro
def test_hmcecs_normal_normal(kernel_cls, num_block, subsample_size):
    true_loc = jnp.array([0.3, 0.1, 0.9])
    num_warmup, num_samples = 200, 200
    data = true_loc + dist.Normal(jnp.zeros(3, ), jnp.ones(3, )).sample(
        random.PRNGKey(1), (10000, ))

    def model(data, subsample_size):
        mean = numpyro.sample('mean', dist.Normal().expand((3, )).to_event(1))
        with numpyro.plate('batch',
                           data.shape[0],
                           dim=-2,
                           subsample_size=subsample_size):
            sub_data = numpyro.subsample(data, 0)
            numpyro.sample("obs", dist.Normal(mean, 1), obs=sub_data)

    ref_params = {
        'mean':
        true_loc + dist.Normal(true_loc, 5e-2).sample(random.PRNGKey(0))
    }
    proxy_fn = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), data, subsample_size)

    samples = mcmc.get_samples()
    assert_allclose(np.mean(mcmc.get_samples()['mean'], axis=0),
                    true_loc,
                    atol=0.1)
    assert len(samples['mean']) == num_samples
Code example #27
File: hmm_enum.py  Project: while519/numpyro
def main(args):

    model = models[args.model]

    _, fetch = load_dataset(JSB_CHORALES, split='train', shuffle=False)
    lengths, sequences = fetch()
    if args.num_sequences:
        sequences = sequences[0:args.num_sequences]
        lengths = lengths[0:args.num_sequences]

    logger.info('-' * 40)
    logger.info('Training {} on {} sequences'.format(
        model.__name__, len(sequences)))

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes with default args)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clip(0, args.truncate)
        sequences = sequences[:, :args.truncate]

    logger.info('Each sequence has shape {}'.format(sequences[0].shape))
    logger.info('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = {'nuts': NUTS, 'hmc': HMC}[args.kernel](model)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, sequences, lengths, args=args)
    mcmc.print_summary()
    logger.info('\nMCMC elapsed time: {}'.format(time.time() - start))
Code example #28
def test_linear_model_log_sigma(
    kernel_cls, N=100, P=50, sigma=0.11, num_warmup=500, num_samples=500
):
    np.random.seed(0)
    X = np.random.randn(N * P).reshape((N, P))
    XX = np.matmul(np.transpose(X), X)
    Y = X[:, 0] + sigma * np.random.randn(N)
    XY = np.sum(X * Y[:, None], axis=0)

    def model(X, Y):
        N, P = X.shape

        log_sigma = numpyro.sample("log_sigma", dist.Normal(1.0))
        sigma = jnp.exp(log_sigma)
        beta = numpyro.sample("beta", dist.Normal(jnp.zeros(P), jnp.ones(P)))
        mean = jnp.sum(beta * X, axis=-1)
        numpyro.deterministic("mean", mean)

        numpyro.sample("obs", dist.Normal(mean, sigma), obs=Y)

    gibbs_fn = partial(_linear_regression_gibbs_fn, X, XX, XY, Y)

    hmc_kernel = kernel_cls(model)
    kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=["beta"])
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )

    mcmc.run(random.PRNGKey(0), X, Y)

    beta_mean = np.mean(mcmc.get_samples()["beta"], axis=0)
    assert_allclose(beta_mean, np.array([1.0] + [0.0] * (P - 1)), atol=0.05)

    sigma_mean = np.exp(np.mean(mcmc.get_samples()["log_sigma"], axis=0))
    assert_allclose(sigma_mean, sigma, atol=0.25)
Code example #29
def test_estimate_likelihood(kernel_cls):
    data_key, tr_key, sub_key, rng_key = random.split(random.PRNGKey(0), 4)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = 0.1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (10_000,)
    )
    n, _ = data.shape
    num_warmup = 200
    num_samples = 200
    num_blocks = 20

    def model(data):
        mean = numpyro.sample(
            "mean", dist.Normal(ref_params, jnp.ones_like(ref_params))
        )
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    proxy_fn = HMCECS.taylor_proxy({"mean": ref_params})
    kernel = HMCECS(kernel_cls(model), proxy=proxy_fn, num_blocks=num_blocks)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)

    mcmc.run(random.PRNGKey(0), data, extra_fields=["hmc_state.potential_energy"])

    pes = mcmc.get_extra_fields()["hmc_state.potential_energy"]
    samples = mcmc.get_samples()
    pes_full = vmap(
        lambda sample: log_density(
            model, (data,), {}, {**sample, **{"N": jnp.arange(n)}}
        )[0]
    )(samples)

    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.0
Code example #30
File: funnel.py  Project: cnheider/numpyro
def run_inference(model, args, rng_key):
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key)
    mcmc.print_summary()
    return mcmc.get_samples()