def test_numpyro_dict_params_defaults_nullcase(input_types):
    """Check that ``numpyro_dict_params`` rejects unsupported value types.

    Every entry of the demo dictionary is set to ``input_types`` and the
    conversion is expected to raise ``ValueError``.

    Args:
        input_types: the value placed under every parameter name
            (presumably supplied by a parametrize decorator outside this
            view -- TODO confirm).
    """
    param_names = ("lengthscale", "variance", "obs_noise")
    demo_params = {name: input_types for name in param_names}

    with pytest.raises(ValueError):
        numpyro_dict_params(demo_params)
def test_numpyro_marginal_ll_tfp_priors_type(n_samples, n_features, n_latents, dtype):
    """Check the model output dtype matches the dtype of the training targets.

    The training data is cast to ``dtype``, a ``tfd.LogNormal(0.0, 10.0)``
    prior is attached through ``add_priors``, and one seeded forward pass of
    the ``numpyro_marginal_ll`` model is compared against ``ds.y.dtype``.

    Args:
        n_samples, n_features, n_latents: forwarded unchanged to
            ``_gen_training_data``.
        dtype: the dtype every leaf of the dataset is cast to.
    """
    # generate the sample data and cast every leaf to the requested type
    dataset = jax.tree_util.tree_map(
        lambda leaf: leaf.astype(dtype),
        _gen_training_data(n_samples, n_features, n_latents),
    )

    # initial parameters and the posterior they belong to
    params, posterior = _get_conjugate_posterior_params()

    # numpyro-style parameter dictionary, then switch it over to priors
    prior_params = add_priors(
        numpyro_dict_params(params),
        tfd.LogNormal(0.0, 10.0),
    )

    # numpyro-style GP model
    model = numpyro_marginal_ll(posterior, prior_params)

    # a single forward pass needs a seeded context
    with numpyro.handlers.seed(rng_seed=KEY):
        pred = model(dataset)

        chex.assert_equal(pred.dtype, dataset.y.dtype)
def test_numpyro_dict_params_defaults_float():
    """Check the defaults ``numpyro_dict_params`` produces for float inputs.

    For every input key the converted entry must hold exactly the fields
    ``init_value``, ``constraint`` and ``param_type``; the initial value must
    equal the input float, the constraint must be ``constraints.positive``
    and the parameter type must be ``"param"``.  The input dictionary must
    be left untouched.
    """
    # reference copy kept aside to detect in-place modification of the input
    expected_params = {"lengthscale": 1.0, "variance": 1.0, "obs_noise": 1.0}
    demo_params = dict(expected_params)

    numpyro_params = numpyro_dict_params(demo_params)

    # same set of parameter names in and out
    assert set(numpyro_params) == set(demo_params.keys())

    expected_fields = {"init_value", "constraint", "param_type"}
    for name, value in demo_params.items():
        entry = numpyro_params[name]
        # exactly the expected fields are present
        assert set(entry.keys()) == expected_fields
        # initial value is carried over unchanged
        chex.assert_equal(entry["init_value"], value)
        # default constraint is positive
        chex.assert_equal(entry["constraint"], constraints.positive)
        # default parameter type is a plain param
        chex.assert_equal(entry["param_type"], "param")

    # the original dictionary was not modified
    chex.assert_equal(demo_params, expected_params)
def test_numpyro_add_constraints_str(variable, constraint):
    """Check ``add_constraints`` sets the constraint of one named variable.

    Args:
        variable: key of the parameter whose constraint is replaced.
        constraint: the constraint expected under that key afterwards.
    """
    # NOTE(review): other tests in this file unpack this helper as
    # ``params, posterior`` -- confirm which return shape is current.
    gpjax_params = _get_conjugate_posterior_params()

    # convert, then attach the constraint to the single named variable
    constrained = add_constraints(
        numpyro_dict_params(gpjax_params), variable, constraint
    )

    # the new dictionary carries the requested constraint
    chex.assert_equal(constrained[variable]["constraint"], constraint)

    # the source parameters still match a freshly built set
    chex.assert_equal(gpjax_params, _get_conjugate_posterior_params())
def test_numpyro_add_constraints_all(constraint):
    """Check ``add_constraints`` applies one constraint to every parameter.

    Args:
        constraint: the constraint expected on every entry afterwards.
    """
    # NOTE(review): other tests in this file unpack this helper as
    # ``params, posterior`` -- confirm which return shape is current.
    gpjax_params = _get_conjugate_posterior_params()

    # convert, then attach the same constraint without naming a variable
    constrained = add_constraints(numpyro_dict_params(gpjax_params), constraint)

    # every entry of the new dictionary carries the constraint
    for entry in constrained.values():
        chex.assert_equal(entry["constraint"], constraint)

    # the source parameters still match a freshly built set
    chex.assert_equal(gpjax_params, _get_conjugate_posterior_params())
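# NOTE(review): stray non-Python text, kept as comments so the file parses
# (Italian "Esempio n. 6" = "Example no. 6", followed by a bare "0").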
def test_numpyro_add_priors_all(prior):
    """Check ``add_priors`` turns every parameter into the given prior.

    Args:
        prior: the prior expected on every entry afterwards.
    """
    # NOTE(review): other tests in this file unpack this helper as
    # ``params, posterior`` -- confirm which return shape is current.
    gpjax_params = _get_conjugate_posterior_params()

    # convert, then attach the same prior without naming a variable
    with_priors = add_priors(numpyro_dict_params(gpjax_params), prior)

    for entry in with_priors.values():
        # each entry is now flagged as a prior and stores the prior itself
        chex.assert_equal(entry["param_type"], "prior")
        chex.assert_equal(entry["prior"], prior)

    # the source parameters still match a freshly built set
    chex.assert_equal(gpjax_params, _get_conjugate_posterior_params())
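# NOTE(review): stray non-Python text, kept as comments so the file parses
# (Italian "Esempio n. 7" = "Example no. 7", followed by a bare "0").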
def test_numpyro_dict_priors_defaults_tfp():
    """Check the entries ``numpyro_dict_params`` builds from TFP priors.

    Each input value is a ``tfd.LogNormal(loc=0.0, scale=1.0)``; the
    converted entry must hold exactly the fields ``prior`` and
    ``param_type``, and its ``prior`` must equal the input distribution.
    """
    param_names = ("lengthscale", "variance", "obs_noise")
    # one separate distribution instance per parameter
    demo_priors = {
        name: tfd.LogNormal(loc=0.0, scale=1.0) for name in param_names
    }

    numpyro_params = numpyro_dict_params(demo_priors)

    # same set of parameter names in and out
    assert set(numpyro_params) == set(demo_priors.keys())

    for name, prior in demo_priors.items():
        entry = numpyro_params[name]
        # exactly the expected fields are present
        assert set(entry.keys()) == {"prior", "param_type"}
        # the stored prior is the one that was passed in
        chex.assert_equal(entry["prior"], prior)
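# NOTE(review): stray non-Python text, kept as comments so the file parses
# (Italian "Esempio n. 8" = "Example no. 8", followed by a bare "0").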
def test_numpyro_add_priors_dict(variable, prior):
    """Check ``add_priors`` accepts a ``{name: prior}`` dictionary.

    Args:
        variable: key of the parameter that receives the prior; it is
            passed through ``str`` when building the dictionary.
        prior: the prior expected under that key afterwards.
    """
    # NOTE(review): other tests in this file unpack this helper as
    # ``params, posterior`` -- confirm which return shape is current.
    gpjax_params = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # attach the prior through a one-entry dictionary
    updated = add_priors(numpyro_params, {str(variable): prior})

    # the named entry is now flagged as a prior and stores the prior itself
    target = updated[variable]
    chex.assert_equal(target["param_type"], "prior")
    chex.assert_equal(target["prior"], prior)

    # the source parameters still match a freshly built set
    chex.assert_equal(gpjax_params, _get_conjugate_posterior_params())
def test_numpyro_marginal_ll_params_shape(n_samples, n_features, n_latents):
    """Check the model output has the same shape as the squeezed targets.

    Args:
        n_samples, n_features, n_latents: forwarded unchanged to
            ``_gen_training_data``.
    """
    dataset = _gen_training_data(n_samples, n_features, n_latents)

    # initial parameters and the posterior they belong to
    params, posterior = _get_conjugate_posterior_params()

    # numpyro-style GP model built from the converted parameters
    model = numpyro_marginal_ll(posterior, numpyro_dict_params(params))

    # a single forward pass needs a seeded context
    with numpyro.handlers.seed(rng_seed=KEY):
        pred = model(dataset)

        chex.assert_equal_shape([dataset.y.squeeze(), pred])
def test_numpyro_marginal_ll_params():
    """Check every converted parameter name shows up in the model trace.

    The model is traced once under a seeded context and the keys of the
    numpyro-style parameter dictionary must be a subset of the trace keys.
    """
    dataset = _gen_training_data(10, 10, 2)

    # initial parameters and the posterior they belong to
    params, posterior = _get_conjugate_posterior_params()

    # numpyro-style parameter dictionary and the GP model built from it
    numpyro_params = numpyro_dict_params(params)
    model = numpyro_marginal_ll(posterior, numpyro_params)

    # record the sites visited during one seeded forward pass
    with numpyro.handlers.seed(rng_seed=KEY):
        traced_sites = numpyro.handlers.trace(model).get_trace(dataset)

    assert set(numpyro_params).issubset(traced_sites)