コード例 #1
0
def test_diverging(kernel_cls, adapt_step_size):
    data = random.normal(random.PRNGKey(0), (1000, ))

    def model(data):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)

    kernel = kernel_cls(model,
                        step_size=10.,
                        adapt_step_size=adapt_step_size,
                        adapt_mass_matrix=False)
    num_warmup = num_samples = 1000
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.warmup(random.PRNGKey(1),
                data,
                extra_fields=['diverging'],
                collect_warmup=True)
    warmup_divergences = mcmc.get_extra_fields()['diverging'].sum()
    mcmc.run(random.PRNGKey(2), data, extra_fields=['diverging'])
    num_divergences = warmup_divergences + mcmc.get_extra_fields(
    )['diverging'].sum()
    if adapt_step_size:
        assert num_divergences <= num_warmup
    else:
        assert_allclose(num_divergences, num_warmup + num_samples)
コード例 #2
0
def test_chain_jit_args_smoke(chain_method, compile_args):
    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    data1 = dist.Categorical(jnp.array([0.1, 0.6,
                                        0.3])).sample(random.PRNGKey(1),
                                                      (50, ))
    data2 = dist.Categorical(jnp.array([0.2, 0.4,
                                        0.4])).sample(random.PRNGKey(1),
                                                      (50, ))
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=2,
        num_samples=5,
        num_chains=2,
        chain_method=chain_method,
        jit_model_args=compile_args,
    )
    mcmc.warmup(random.PRNGKey(0), data1)
    mcmc.run(random.PRNGKey(1), data1)
    # this should be fast if jit_model_args=True
    mcmc.run(random.PRNGKey(2), data2)
コード例 #3
0
def test_mcmc_progbar():
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 10, 10

    def model(data):
        mean = numpyro.param('mean', 0.)
        std = numpyro.param('std', 1., constraint=constraints.positive)
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(3), data)
    mcmc1 = MCMC(kernel, num_warmup, num_samples, progress_bar=False)
    mcmc1.run(random.PRNGKey(2), data)

    with pytest.raises(AssertionError):
        check_close(mcmc1.get_samples(),
                    mcmc.get_samples(),
                    atol=1e-4,
                    rtol=1e-4)
    mcmc1.warmup(random.PRNGKey(2), data)
    mcmc1.run(random.PRNGKey(3), data)
    check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)
    check_close(mcmc1._warmup_state, mcmc._warmup_state, atol=1e-4, rtol=1e-4)
コード例 #4
0
def test_uniform_normal(forward_mode_differentiation):
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(dist.Uniform(0, 1),
                                             AffineTransform(0, alpha)),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))
    kernel = NUTS(model=model,
                  forward_mode_differentiation=forward_mode_differentiation)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    assert mcmc.post_warmup_state is not None
    warmup_samples = mcmc.get_samples()
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(warmup_samples["loc"]) == num_warmup
    assert len(samples["loc"]) == num_samples
    assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)

    mcmc.post_warmup_state = mcmc.last_state
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(samples["loc"]) == num_samples
    assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)
コード例 #5
0
def test_mcmc_progbar():
    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 10, 10

    def model(data):
        mean = numpyro.sample("mean", dist.Normal(0, 1).mask(False))
        std = numpyro.sample("std", dist.LogNormal(0, 1).mask(False))
        return numpyro.sample("obs", dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(3), data)
    mcmc1 = MCMC(kernel,
                 num_warmup=num_warmup,
                 num_samples=num_samples,
                 progress_bar=False)
    mcmc1.run(random.PRNGKey(2), data)

    with pytest.raises(AssertionError):
        check_close(mcmc1.get_samples(),
                    mcmc.get_samples(),
                    atol=1e-4,
                    rtol=1e-4)
    mcmc1.warmup(random.PRNGKey(2), data)
    mcmc1.run(random.PRNGKey(3), data)
    check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)
    check_close(mcmc1.post_warmup_state,
                mcmc.post_warmup_state,
                atol=1e-4,
                rtol=1e-4)
コード例 #6
0
def test_compile_warmup_run(num_chains, chain_method, progress_bar):
    def model():
        numpyro.sample("x", dist.Normal(0, 1))

    if num_chains == 1 and chain_method in ['sequential', 'vectorized']:
        pytest.skip('duplicated test')
    if num_chains > 1 and chain_method == 'parallel':
        pytest.skip('duplicated test')

    rng_key = random.PRNGKey(0)
    num_samples = 10
    mcmc = MCMC(NUTS(model), 10, num_samples, num_chains,
                chain_method=chain_method, progress_bar=progress_bar)

    mcmc.run(rng_key)
    expected_samples = mcmc.get_samples()["x"]

    mcmc._compile(rng_key)
    # no delay after compiling
    mcmc.warmup(rng_key)
    mcmc.run(mcmc._warmup_state.rng_key)
    actual_samples = mcmc.get_samples()["x"]

    assert_allclose(actual_samples, expected_samples)

    # test for reproducible
    if num_chains > 1:
        mcmc = MCMC(NUTS(model), 10, num_samples, 1, progress_bar=progress_bar)
        rng_key = random.split(rng_key)[0]
        mcmc.run(rng_key)
        first_chain_samples = mcmc.get_samples()["x"]
        assert_allclose(actual_samples[:num_samples], first_chain_samples, atol=1e-5)
コード例 #7
0
def test_uniform_normal():
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    assert mcmc.post_warmup_state is not None
    warmup_samples = mcmc.get_samples()
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)

    mcmc.post_warmup_state = mcmc.last_state
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(samples['loc']) == num_samples
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)
コード例 #8
0
def test_chain_smoke(chain_method, compile_args):
    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    data = dist.Categorical(jnp.array([0.1, 0.6, 0.3])).sample(random.PRNGKey(1), (2000,))
    kernel = NUTS(model)
    mcmc = MCMC(kernel, 2, 5, num_chains=2, chain_method=chain_method, jit_model_args=compile_args)
    mcmc.warmup(random.PRNGKey(0), data)
    mcmc.run(random.PRNGKey(1), data)
コード例 #9
0
def test_improper_prior():
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 1000, 8000

    def model(data):
        mean = numpyro.param('mean', 0.)
        std = numpyro.param('std', 1., constraint=constraints.positive)
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup, num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['mean']), true_mean, rtol=0.05)
    assert_allclose(np.mean(samples['std']), true_std, rtol=0.05)
コード例 #10
0
def test_improper_prior():
    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 1000, 8000

    def model(data):
        mean = numpyro.sample("mean", dist.Normal(0, 1).mask(False))
        std = numpyro.sample(
            "std", dist.ImproperUniform(dist.constraints.positive, (), ()))
        return numpyro.sample("obs", dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["mean"]), true_mean, rtol=0.05)
    assert_allclose(jnp.mean(samples["std"]), true_std, rtol=0.05)
コード例 #11
0
def test_uniform_normal():
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = mcmc.get_samples()
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    assert_allclose(np.mean(samples['loc'], 0), true_coef, atol=0.05)
コード例 #12
0
def run_model(
    model_func,
    data,
    ep,
    num_samples=500,
    num_warmup=500,
    num_chains=4,
    target_accept=0.75,
    max_tree_depth=15,
    save_results=True,
    output_fname=None,
    model_kwargs=None,
    save_json=False,
    chain_method="parallel",
    heuristic_step_size=True,
):
    """
    Model run utility

    :param model_func: numpyro model
    :param data: PreprocessedData object
    :param ep: EpidemiologicalParameters object
    :param num_samples: number of samples
    :param num_warmup: number of warmup samples
    :param num_chains: number of chains
    :param target_accept: target accept
    :param max_tree_depth: maximum treedepth
    :param save_results: whether to save full results
    :param output_fname: output filename
    :param model_kwargs: model kwargs -- extra arguments for the model function
    :param save_json: whether to save json
    :param chain_method: Numpyro chain method to use
    :param heuristic_step_size: whether to find a heuristic step size
    :return: posterior_samples, warmup_samples, info_dict (dict with assorted diagnostics), Numpyro mcmc object
    """
    print(
        f"Running {num_chains} chains, {num_samples} per chain with {num_warmup} warmup steps"
    )
    nuts_kernel = NUTS(
        model_func,
        init_strategy=init_to_median,
        target_accept_prob=target_accept,
        max_tree_depth=max_tree_depth,
        find_heuristic_step_size=heuristic_step_size,
    )
    mcmc = MCMC(
        nuts_kernel,
        num_samples=num_samples,
        num_warmup=num_warmup,
        num_chains=num_chains,
        chain_method=chain_method,
    )
    rng_key = random.PRNGKey(0)

    # hmcstate = nuts_kernel.init(rng_key, 1, model_args=(data, ep))
    # nRVs = hmcstate.adapt_state.inverse_mass_matrix.size
    # inverse_mass_matrix = init_diag_inv_mass_mat * jnp.ones(nRVs)
    # mass_matrix_sqrt_inv = np.sqrt(inverse_mass_matrix)
    # mass_matrix_sqrt = 1./mass_matrix_sqrt_inv
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(inverse_mass_matrix=inverse_mass_matrix))
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(mass_matrix_sqrt_inv=mass_matrix_sqrt_inv))
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(mass_matrix_sqrt=mass_matrix_sqrt))
    # mcmc.post_warmup_state = hmcstate

    info_dict = {
        "model_name": model_func.__name__,
    }

    start = time.time()
    if model_kwargs is None:
        model_kwargs = {}

    info_dict["model_kwargs"] = model_kwargs

    # also collect some extra information for better diagonstics!
    print(f"Warmup Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    mcmc.warmup(
        rng_key,
        data,
        ep,
        **model_kwargs,
        collect_warmup=True,
        extra_fields=["num_steps", "mean_accept_prob", "adapt_state"],
    )
    mcmc.get_extra_fields()["num_steps"].block_until_ready()

    info_dict["warmup"] = {}
    info_dict["warmup"]["num_steps"] = np.array(
        mcmc.get_extra_fields()["num_steps"]).tolist()
    info_dict["warmup"]["step_size"] = np.array(
        mcmc.get_extra_fields()["adapt_state"].step_size).tolist()
    info_dict["warmup"]["inverse_mass_matrix"] = {}

    all_mass_mats = jnp.array(
        jnp.array_split(
            mcmc.get_extra_fields()["adapt_state"].inverse_mass_matrix,
            num_chains,
            axis=0,
        ))

    print(all_mass_mats.shape)

    for i in range(num_chains):
        info_dict["warmup"]["inverse_mass_matrix"][
            f"chain_{i}"] = all_mass_mats[i, -1, :].tolist()

    info_dict["warmup"]["mean_accept_prob"] = np.array(
        mcmc.get_extra_fields()["mean_accept_prob"]).tolist()

    warmup_samples = mcmc.get_samples()

    print(f"Sample Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    mcmc.run(
        rng_key,
        data,
        ep,
        **model_kwargs,
        extra_fields=["num_steps", "mean_accept_prob", "adapt_state"],
    )

    posterior_samples = mcmc.get_samples()
    # if you don't block this, the timer won't quite work properly.
    posterior_samples[list(posterior_samples.keys())[0]].block_until_ready()
    print(f"Sample Finished: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    end = time.time()
    time_per_sample = float(end - start) / num_samples
    divergences = int(mcmc.get_extra_fields()["diverging"].sum())

    info_dict["time_per_sample"] = time_per_sample
    info_dict["total_runtime"] = float(end - start)
    info_dict["divergences"] = divergences

    info_dict["sample"] = {}
    info_dict["sample"]["num_steps"] = np.array(
        mcmc.get_extra_fields()["num_steps"]).tolist()
    info_dict["sample"]["mean_accept_prob"] = np.array(
        mcmc.get_extra_fields()["mean_accept_prob"]).tolist()
    info_dict["sample"]["step_size"] = np.array(
        mcmc.get_extra_fields()["adapt_state"].step_size).tolist()

    print(f"Sampling {num_samples} samples per chain took {end - start:.2f}s")
    print(f"There were {divergences} divergences.")

    grouped_posterior_samples = mcmc.get_samples(True)

    all_ess = np.array([])
    for k in grouped_posterior_samples.keys():
        ess = numpyro.diagnostics.effective_sample_size(
            np.asarray(grouped_posterior_samples[k]))
        all_ess = np.append(all_ess, ess)

    print(f"{np.sum(np.isnan(all_ess))}  ESS were nan")
    all_ess = all_ess[np.logical_not(np.isnan(all_ess))]

    info_dict["ess"] = {
        "med": float(np.percentile(all_ess, 50)),
        "lower": float(np.percentile(all_ess, 2.5)),
        "upper": float(np.percentile(all_ess, 97.5)),
        "min": float(np.min(all_ess)),
        "max": float(np.max(all_ess)),
    }
    print(
        f"Mean ESS: {info_dict['ess']['med']:.2f} [{info_dict['ess']['lower']:.2f} ... {info_dict['ess']['upper']:.2f}]"
    )

    if num_chains > 1:
        all_rhat = np.array([])
        for k in grouped_posterior_samples.keys():
            rhat = numpyro.diagnostics.gelman_rubin(
                np.asarray(grouped_posterior_samples[k]))
            all_rhat = np.append(all_rhat, rhat)

        print(f"{np.sum(np.isnan(all_rhat))} Rhat were nan")
        all_rhat = all_rhat[np.logical_not(np.isnan(all_rhat))]

        info_dict["rhat"] = {
            "med": float(np.percentile(all_rhat, 50)),
            "upper": float(np.percentile(all_rhat, 97.5)),
            "lower": float(np.percentile(all_rhat, 2.5)),
            "min": float(np.max(all_rhat)),
            "max": float(np.min(all_rhat)),
        }

        print(
            f"Rhat: {info_dict['rhat']['med']:.2f} [{info_dict['rhat']['lower']:.2f} ... {info_dict['rhat']['upper']:.2f}]"
        )

    if save_results:
        print("Saving .netcdf")
        try:
            inf_data = az.from_numpyro(mcmc)

            if output_fname is None:
                output_fname = f'{model_func.__name__}-{datetime.now(tz=None).strftime("%d-%m;%H-%M-%S")}.netcdf'

            az.to_netcdf(inf_data, output_fname)

            json_fname = output_fname.replace(".netcdf", ".json")
            if save_json:
                print("Saving Json")
                with open(json_fname, "w") as f:
                    json.dump(info_dict, f, ensure_ascii=False, indent=4)

        except Exception as e:
            print(e)

    return posterior_samples, warmup_samples, info_dict, mcmc