Example #1
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    if kernel_cls is SA:
        warmup_steps, num_samples = (100000, 100000)
    elif kernel_cls is BarkerMH:
        warmup_steps, num_samples = (2000, 12000)
    else:
        warmup_steps, num_samples = (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=8, find_heuristic_step_size=True)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    # those coefficients are found by doing MAP inference using AutoDelta
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples['coefs'], 0), expected_coefs, atol=0.1)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == jnp.float64
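
The MAP coefficients asserted above come from AutoDelta, as the inline comment notes. A minimal sketch of how they could be reproduced, reusing `model` and `labels` from the test; the step size and step count are illustrative, not taken from the original:

from numpyro.infer import SVI, Trace_ELBO, autoguide

guide = autoguide.AutoDelta(model)
svi = SVI(model, guide, numpyro.optim.Adam(step_size=0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(3), 2000, labels)
# AutoDelta stores the point estimate for site 'coefs' under 'coefs_auto_loc'
map_coefs = svi_result.params["coefs_auto_loc"]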
Example #2
def run_inference(model, args, rng_key):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key)
    mcmc.print_summary()
    return mcmc.get_samples()
Example #3
def sample(model,
           num_samples,
           num_warmup,
           num_chains,
           seed=0,
           chain_method="parallel",
           summary=True,
           **kwargs):
    """Run the No-U-Turn sampler
    """
    rng_key = random.PRNGKey(seed)
    kernel = NUTS(model)
    # Note: sampling more than one chain doesn't show a progress bar
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
                chain_method=chain_method)
    mcmc.run(rng_key, **kwargs)

    if summary:
        mcmc.print_summary()

    # Return a fitted MCMC object
    return mcmc
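
A hypothetical call to this helper; `coin_model` and its data are illustrative, and the extra keyword argument is forwarded through `**kwargs` to `mcmc.run`:

def coin_model(heads):
    p = numpyro.sample("p", dist.Beta(1., 1.))
    numpyro.sample("obs", dist.Bernoulli(p), obs=heads)

fitted = sample(coin_model, num_samples=1000, num_warmup=500, num_chains=1,
                heads=jnp.array([1., 0., 1., 1.]))
posterior = fitted.get_samples()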
Example #4
def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
    svi_key, mcmc_key = random.split(hmcecs_key)

    # find reference parameters for second order taylor expansion to estimate likelihood (taylor_proxy)
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    guide = autoguide.AutoDelta(model)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs,
                         args.subsample_size)
    params, losses = svi_result.params, svi_result.losses
    ref_params = {"theta": params["theta_auto_loc"]}

    # taylor proxy estimates the log likelihood (ll) around ref_params by
    # taylor_expansion(ll, theta_curr) +
    #     sum_{i in subsample} (ll_i(theta_curr) - taylor_expansion(ll_i, theta_curr))
    proxy = HMCECS.taylor_proxy(ref_params)

    kernel = HMCECS(inner_kernel, num_blocks=args.num_blocks, proxy=proxy)
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)

    mcmc.run(mcmc_key, data, obs, args.subsample_size)
    mcmc.print_summary()
    return losses, mcmc.get_samples()
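
A hedged sketch of how run_hmcecs might be invoked; NumPyro's HMCECS example uses a NUTS kernel over the same model as the inner kernel, and `args`, `data`, and `obs` are assumed to come from the surrounding script:

inner_kernel = NUTS(model)
losses, samples = run_hmcecs(random.PRNGKey(0), args, data, obs, inner_kernel)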
Example #5
def sample_model(rng_key,
                 model,
                 model_args_dict,
                 num_warmup=500,
                 num_samples=500,
                 num_chains=1):

    kernel = NUTS(model)

    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                num_chains=num_chains, progress_bar=True)

    mcmc.run(rng_key, **model_args_dict)

    mcmc.print_summary()

    # divergences = mcmc.get_extra_fields()["diverging"]

    samples = mcmc.get_samples()
    # samples['divergences'] = divergences

    if 'b_condition' not in samples:
        # draw posterior predictive samples for the missing site by calling
        # the Predictive object directly
        predictive = numpyro.infer.Predictive(model, samples)
        samples['b_condition'] = predictive(rng_key, **model_args_dict)['b_condition']

    return samples
Example #6
def run_inference(model, inputs, method=None):
    if method is None:
        # NUTS
        num_samples = 5000
        logger.info('NUTS sampling')
        kernel = NUTS(model)
        mcmc = MCMC(kernel, num_warmup=300, num_samples=num_samples)
        rng_key = random.PRNGKey(0)
        mcmc.run(rng_key, **inputs, extra_fields=('potential_energy', ))
        logger.info(r'MCMC summary for: {}'.format(model.__name__))
        mcmc.print_summary(exclude_deterministic=False)
        samples = mcmc.get_samples()
    else:
        #SVI
        logger.info('Guide generation...')
        rng_key = random.PRNGKey(0)
        guide = AutoDiagonalNormal(model=model)
        logger.info('Optimizer generation...')
        optim = Adam(0.05)
        logger.info('SVI generation...')
        svi = SVI(model, guide, optim, AutoContinuousELBO(), **inputs)
        init_state = svi.init(rng_key)
        logger.info('Scan...')
        state, loss = lax.scan(lambda x, i: svi.update(x), init_state,
                               np.zeros(2000))
        params = svi.get_params(state)
        samples = guide.sample_posterior(random.PRNGKey(1), params, (1000, ))
        logger.info(r'SVI summary for: {}'.format(model.__name__))
        numpyro.diagnostics.print_summary(samples,
                                          prob=0.90,
                                          group_by_chain=False)
    return samples
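
The manual lax.scan loop above is how SVI was stepped in early NumPyro releases. In recent versions the same loop is available as SVI.run, so the SVI branch could be written as the sketch below, assuming a release where Trace_ELBO supports autoguides:

svi = SVI(model, guide, optim, Trace_ELBO(), **inputs)
svi_result = svi.run(rng_key, 2000)  # runs the update loop internally and tracks losses
params, losses = svi_result.params, svi_result.losses
samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))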
Example #7
    def infer(self,
              num_warmup=1000,
              num_samples=1000,
              num_chains=1,
              rng_key=PRNGKey(1),
              **args):
        '''Fit using MCMC'''

        # Start from the caller-supplied source of randomness; split keys for
        # subsequent operations (the original overwrote rng_key with PRNGKey(0),
        # which silently ignored the argument).
        rng_key, rng_key_ = split(rng_key)

        args = dict(self.args, **args)

        #kernel = NUTS(self, init_strategy = numpyro.infer.util.init_to_median())
        kernel = NUTS(self)

        mcmc = MCMC(kernel,
                    num_warmup=num_warmup,
                    num_samples=num_samples,
                    num_chains=num_chains)

        mcmc.run(rng_key, **self.obs, **args)
        mcmc.print_summary()

        self.mcmc = mcmc
        self.mcmc_samples = mcmc.get_samples()

        return self.mcmc_samples
Example #8
def run_inference(model):
    kernel = NUTS(model)
    rng_key = random.PRNGKey(0)
    mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
    mcmc.run(rng_key)
    mcmc.print_summary(exclude_deterministic=False)
    return mcmc.get_samples()
Example #9
def main(args):

    model = models[args.model]

    _, fetch = load_dataset(JSB_CHORALES, split='train', shuffle=False)
    lengths, sequences = fetch()
    if args.num_sequences:
        sequences = sequences[0:args.num_sequences]
        lengths = lengths[0:args.num_sequences]

    logger.info('-' * 40)
    logger.info('Training {} on {} sequences'.format(
        model.__name__, len(sequences)))

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes with default args)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clip(0, args.truncate)
        sequences = sequences[:, :args.truncate]

    logger.info('Each sequence has shape {}'.format(sequences[0].shape))
    logger.info('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = {'nuts': NUTS, 'hmc': HMC}[args.kernel](model)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, sequences, lengths, args=args)
    mcmc.print_summary()
    logger.info('\nMCMC elapsed time: {}'.format(time.time() - start))
Example #10
def run_inference(model, args, rng_key, X, Y):
    start = time.time()
    # demonstrate how to use different HMC initialization strategies
    if args.init_strategy == "value":
        init_strategy = init_to_value(values={
            "kernel_var": 1.0,
            "kernel_noise": 0.05,
            "kernel_length": 0.5
        })
    elif args.init_strategy == "median":
        init_strategy = init_to_median(num_samples=10)
    elif args.init_strategy == "feasible":
        init_strategy = init_to_feasible()
    elif args.init_strategy == "sample":
        init_strategy = init_to_sample()
    elif args.init_strategy == "uniform":
        init_strategy = init_to_uniform(radius=1)
    kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
Example #11
    def fit(self, df, iter=500, seed=42, **kwargs):
        teams = sorted(list(set(df["home_team"]) | set(df["away_team"])))
        home_team = df["home_team"].values
        away_team = df["away_team"].values
        home_goals = df["home_goals"].values
        away_goals = df["away_goals"].values
        gameweek = ((df["date"] - df["date"].min()).dt.days // 7).values

        self.team_to_index = {team: i for i, team in enumerate(teams)}
        self.index_to_team = {
            value: key
            for key, value in self.team_to_index.items()
        }
        self.n_teams = len(teams)
        self.min_date = df["date"].min()

        conditioned_model = condition(self.model,
                                      param_map={
                                          "home_goals": home_goals,
                                          "away_goals": away_goals
                                      })
        nuts_kernel = NUTS(conditioned_model)
        mcmc = MCMC(nuts_kernel,
                    num_warmup=iter // 2,
                    num_samples=iter,
                    **kwargs)
        rng_key = random.PRNGKey(seed)
        mcmc.run(rng_key, home_team, away_team, gameweek)

        self.samples = mcmc.get_samples()
        mcmc.print_summary()
        return self
Example #12
def test_logistic_regression():
    from tensorflow_probability.substrates.jax import distributions as tfd

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = tfd.Bernoulli(logits=logits).sample(seed=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs",
                               tfd.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data,
                                                         axis=-1))
        return numpyro.sample("obs", tfd.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.22)
Example #13
def test_logistic_regression():
    from numpyro.contrib.tfp import distributions as dist

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits)(rng_key=random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs',
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data,
                                                         axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples['coefs'], 0), expected_coefs, atol=0.22)
Example #14
def test_unnormalized_normal_x64(kernel_cls, dense_mass):
    true_mean, true_std = 1.0, 0.5
    num_warmup, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)

    def potential_fn(z):
        return 0.5 * jnp.sum(((z - true_mean) / true_std)**2)

    init_params = jnp.array(0.0)
    if kernel_cls is SA:
        kernel = SA(potential_fn=potential_fn, dense_mass=dense_mass)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(potential_fn=potential_fn, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(potential_fn=potential_fn,
                            trajectory_length=8,
                            dense_mass=dense_mass)
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    mcmc.print_summary()
    hmc_states = mcmc.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)

    if "JAX_ENABLE_X64" in os.environ:
        assert hmc_states.dtype == jnp.float64
Example #15
def inference(
    model: Callable,
    num_categories: int,
    num_words: int,
    supervised_categories: jnp.ndarray,
    supervised_words: jnp.ndarray,
    unsupervised_words: jnp.ndarray,
    rng_key: np.ndarray,
    *,
    num_warmup: int = 500,
    num_samples: int = 1000,
    num_chains: int = 1,
    verbose: bool = True,
) -> Dict[str, jnp.ndarray]:

    kernel = NUTS(model)
    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains)
    mcmc.run(
        rng_key,
        num_categories,
        num_words,
        supervised_categories,
        supervised_words,
        unsupervised_words,
    )
    if verbose:
        mcmc.print_summary()

    return mcmc.get_samples()
Example #16
def main(args):
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # use dense_mass for better mixing rate
    mcmc = MCMC(NUTS(model, dense_mass=True),
                num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(PRNGKey(1), N=data.shape[0], y=jnp.log(data))
    mcmc.print_summary()

    # predict populations
    y_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    pop_pred = jnp.exp(y_pred)
    mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.plot(year, mu[:, 1], "b--", label="pred lynx")
    plt.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    plt.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    plt.gca().set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    plt.title("Posterior predictive (80% CI) with predator-prey pattern.")
    plt.legend()

    plt.tight_layout()
    plt.savefig("ode_plot.pdf")
Example #17
def test_beta_bernoulli_x64(kernel_cls):
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000)

    def model(data):
        alpha = np.array([1.1, 1.1])
        beta = np.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    if kernel_cls is SA:
        kernel = SA(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=1.)
    mcmc = MCMC(kernel,
                num_warmup=warmup_steps,
                num_samples=num_samples,
                progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == np.float64
Example #18
def run_hmc(mcmc_key, args, data, obs, kernel):
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)
    mcmc.run(mcmc_key, data, obs, None)
    mcmc.print_summary()
    return mcmc.get_samples()
Example #19
def main(args):
    annotators, annotations = get_data()
    model = NAME_TO_MODEL[args.model]
    data = ((annotations, ) if model in [multinomial, item_difficulty] else
            (annotators, annotations))

    mcmc = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(random.PRNGKey(0), *data)
    mcmc.print_summary()

    posterior_samples = mcmc.get_samples()
    predictive = Predictive(model, posterior_samples, infer_discrete=True)
    discrete_samples = predictive(random.PRNGKey(1), *data)

    item_class = vmap(lambda x: jnp.bincount(x, length=4),
                      in_axes=1)(discrete_samples["c"].squeeze(-1))
    print("Histogram of the predicted class of each item:")
    row_format = "{:>10}" * 5
    print(row_format.format("", *["c={}".format(i) for i in range(4)]))
    for i, row in enumerate(item_class):
        print(row_format.format(f"item[{i}]", *row))
Example #20
def run_nuts(mcmc_key, args, X, Y):
    mcmc = MCMC(NUTS(model),
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)
    mcmc.run(mcmc_key, X, Y)
    mcmc.print_summary()
    return mcmc.get_samples()
Example #21
def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim),
                                                    np.ones(dim)))
        logits = numpyro.deterministic('logits', np.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    else:
        kernel = kernel_cls(model=model, trajectory_length=8)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == np.float64
Example #22
def run_inference(args, data):
    print("=== Performing Nested Sampling ===")
    ns = NestedSampler(model)
    ns.run(random.PRNGKey(0), **data, enum=args.enum)
    # TODO: Remove this condition when jaxns is compatible with the latest jax version.
    if jax.__version__ < "0.2.21":
        ns.print_summary()
    # samples obtained from nested sampler are weighted, so
    # we need to provide random key to resample from those weighted samples
    ns_samples = ns.get_samples(random.PRNGKey(1),
                                num_samples=args.num_samples)

    print("\n=== Performing MCMC Sampling ===")
    if args.enum:
        mcmc = MCMC(NUTS(model),
                    num_warmup=args.num_warmup,
                    num_samples=args.num_samples)
    else:
        mcmc = MCMC(
            DiscreteHMCGibbs(NUTS(model)),
            num_warmup=args.num_warmup,
            num_samples=args.num_samples,
        )
    mcmc.run(random.PRNGKey(2), **data, enum=args.enum)
    mcmc.print_summary()
    mcmc_samples = mcmc.get_samples()

    return ns_samples["x"], mcmc_samples["x"]
Example #23
class NutsHandler(Handler):
    def __init__(
        self,
        model,
        posterior=None,
        num_warmup=2000,
        num_samples=10000,
        num_chains=1,
        key=0,
        *args,
        **kwargs,
    ):
        self.model = model
        self.rng_key, self.rng_key_ = random.split(random.PRNGKey(key))

        if posterior is not None:
            self.mcmc = posterior
            self.posterior = self.mcmc.get_samples()
        else:
            self.kernel = NUTS(model, **kwargs)
            self.mcmc = MCMC(self.kernel,
                             num_warmup=num_warmup,
                             num_samples=num_samples,
                             num_chains=num_chains)

    def _select(self, which):
        assert which in [
            "prior",
            "posterior",
            "posterior_predictive",
        ], "Please select from 'prior', 'posterior' or 'posterior_predictive'."
        assert hasattr(self,
                       which), f"NutsHandler did not compute the {which} yet."
        return getattr(self, which)

    def get_prior(self, *args, **kwargs):
        predictive = Predictive(self.model, num_samples=self.mcmc.num_samples)
        self.prior = predictive(self.rng_key_, *args, **kwargs)

    def get_posterior_predictive(self, *args, **kwargs):
        predictive = Predictive(self.model, self.posterior, **kwargs)
        self.posterior_predictive = predictive(self.rng_key_, *args)

    def fit(self, *args, **kwargs):
        self.mcmc.run(self.rng_key_, *args, **kwargs)
        self.posterior = self.mcmc.get_samples()

    def summary(self, *args, **kwargs):
        self.mcmc.print_summary(*args, **kwargs)

    def dump(self, path):
        with open(path, "wb") as f:
            dill.dump(self.mcmc, f)

    @staticmethod
    def from_dump(model, path):
        with open(path, "rb") as f:
            posterior = dill.load(f)
        return NutsHandler(model, posterior=posterior)
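
A hypothetical round trip with the handler above, with `model` and `data` standing in for the user's own; dump and from_dump persist the fitted MCMC object with dill:

handler = NutsHandler(model, num_warmup=1000, num_samples=2000)
handler.fit(data)
handler.summary()
handler.dump("posterior.dill")
restored = NutsHandler.from_dump(model, "posterior.dill")
samples = restored.posterior  # set in __init__ when a posterior is supplied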
Example #24
def run_inference(model, args, rng_key, X, Y, D_H):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains)
    mcmc.run(rng_key, X, Y, D_H)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
Example #25
def run_inference(model, args, rng_key, X, Y):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()
Example #26
def test_random_module_mcmc(backend, init):

    if backend == "flax":
        import flax

        linear_module = flax.linen.Dense(features=1)
        bias_name = "bias"
        weight_name = "kernel"
        random_module = random_flax_module
        kwargs_name = "inputs"
    elif backend == "haiku":
        import haiku as hk

        linear_module = hk.transform(lambda x: hk.Linear(1)(x))
        bias_name = "linear.b"
        weight_name = "linear.w"
        random_module = random_haiku_module
        kwargs_name = "x"

    N, dim = 3000, 3
    num_warmup, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1.0, dim + 1.0)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    if init == "shape":
        kwargs = {"input_shape": (3,)}
    elif init == "kwargs":
        kwargs = {kwargs_name: data}

    def model(data, labels):
        nn = random_module(
            "nn",
            linear_module,
            {bias_name: dist.Cauchy(), weight_name: dist.Normal()},
            **kwargs
        )
        logits = nn(data).squeeze(-1)
        numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(2), data, labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert set(samples.keys()) == {
        "nn/{}".format(bias_name),
        "nn/{}".format(weight_name),
    }
    assert_allclose(
        np.mean(samples["nn/{}".format(weight_name)].squeeze(-1), 0),
        true_coefs,
        atol=0.22,
    )
Example #27
def run_inference(model, capture_history, sex, rng_key, args):
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    mcmc = MCMC(kernel, num_warmup=args.num_warmup, num_samples=args.num_samples,
                num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, capture_history, sex)
    mcmc.print_summary()
    return mcmc.get_samples()
Example #28
def benchmark_hmc(args, features, labels):
    step_size = np.sqrt(0.5 / features.shape[0])
    trajectory_length = step_size * args.num_steps
    rng_key = random.PRNGKey(1)
    start = time.time()
    kernel = NUTS(model, trajectory_length=trajectory_length)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=args.num_samples)
    mcmc.run(rng_key, features, labels)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
Example #29
def run_hmc(rng_key, model, data, num_mix_comp, args, bvm_init_locs):
    kernel = NUTS(model,
                  init_strategy=init_to_value(values=bvm_init_locs),
                  max_tree_depth=7)
    mcmc = MCMC(kernel,
                num_samples=args.num_samples,
                num_warmup=args.num_warmup)
    mcmc.run(rng_key, data, len(data), num_mix_comp)
    mcmc.print_summary()
    post_samples = mcmc.get_samples()
    return post_samples
Example #30
def run_mcmc(model, args, X, Y):
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(PRNGKey(1), X, Y)
    mcmc.print_summary()
    return mcmc.get_samples()