def test_numpyro_dict_params_defaults_nullcase(input_types):
    """Unsupported parameter values make ``numpyro_dict_params`` raise.

    Every entry of the demo dictionary is set to the ``input_types`` value
    supplied by pytest; the conversion is expected to reject it with a
    ``ValueError``.
    """
    param_names = ("lengthscale", "variance", "obs_noise")
    bad_params = {name: input_types for name in param_names}

    with pytest.raises(ValueError):
        numpyro_dict_params(bad_params)
def test_numpyro_marginal_ll_tfp_priors_type(n_samples, n_features, n_latents, dtype):
    """A forward pass with LogNormal priors keeps the dtype of the targets.

    The training data is cast to ``dtype``, and every parameter is given a
    ``tfd.LogNormal(0.0, 10.0)`` prior. One seeded call of the numpyro
    marginal-likelihood model must then return a result whose dtype matches
    ``ds.y``.
    """
    # synthetic training data, cast leaf-by-leaf to the requested dtype
    training_data = jax.tree_util.tree_map(
        lambda leaf: leaf.astype(dtype),
        _gen_training_data(n_samples, n_features, n_latents),
    )

    # GPJax params -> numpyro-style params, then attach the same prior to all
    gpjax_params, posterior = _get_conjugate_posterior_params()
    prior_params = add_priors(
        numpyro_dict_params(gpjax_params), tfd.LogNormal(0.0, 10.0)
    )

    model = numpyro_marginal_ll(posterior, prior_params)

    # the seed handler supplies the RNG key for the model's random sites
    with numpyro.handlers.seed(rng_seed=KEY):
        prediction = model(training_data)

    chex.assert_equal(prediction.dtype, training_data.y.dtype)
def test_numpyro_dict_params_defaults_float():
    """Plain float parameters get default numpyro metadata.

    Each entry should become a dict with exactly ``init_value``,
    ``constraint`` and ``param_type`` keys. It should keep its float as the
    initial value, default to a positive constraint, and be typed as
    ``"param"``. The input dictionary must be left untouched.
    """
    expected_keys = {"init_value", "constraint", "param_type"}
    demo_params = dict(lengthscale=1.0, variance=1.0, obs_noise=1.0)

    numpyro_params = numpyro_dict_params(demo_params)

    # one output entry per input parameter, no more, no less
    assert set(numpyro_params) == set(demo_params.keys())

    for name in demo_params:
        entry = numpyro_params[name]
        assert set(entry.keys()) == expected_keys
        chex.assert_equal(entry["init_value"], demo_params[name])
        chex.assert_equal(entry["constraint"], constraints.positive)
        chex.assert_equal(entry["param_type"], "param")

    # the conversion must not mutate its input
    chex.assert_equal(
        demo_params, {"lengthscale": 1.0, "variance": 1.0, "obs_noise": 1.0}
    )
def test_numpyro_add_constraints_str(variable, constraint):
    """``add_constraints`` sets the constraint of one named parameter.

    ``_get_conjugate_posterior_params`` returns a ``(params, posterior)``
    pair, so only the parameter dictionary is unpacked and converted; the
    posterior is not needed here.
    """
    gpjax_params, _ = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # add constraint to the named variable
    new_numpyro_params = add_constraints(numpyro_params, variable, constraint)

    # check if constraint in new dictionary
    chex.assert_equal(new_numpyro_params[variable]["constraint"], constraint)

    # check we didn't modify the original GPJax parameter dictionary
    fresh_params, _ = _get_conjugate_posterior_params()
    chex.assert_equal(gpjax_params, fresh_params)
def test_numpyro_add_constraints_all(constraint):
    """``add_constraints`` without a variable name applies to every parameter.

    ``_get_conjugate_posterior_params`` returns a ``(params, posterior)``
    pair, so only the parameter dictionary is unpacked and converted.
    """
    gpjax_params, _ = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # add constraint to all parameters
    new_numpyro_params = add_constraints(numpyro_params, constraint)

    for iparams in new_numpyro_params.values():
        # check if constraint in new dictionary
        chex.assert_equal(iparams["constraint"], constraint)

    # check we didn't modify the original GPJax parameter dictionary
    fresh_params, _ = _get_conjugate_posterior_params()
    chex.assert_equal(gpjax_params, fresh_params)
def test_numpyro_add_priors_all(prior):
    """``add_priors`` with a single distribution turns every parameter into a prior.

    ``_get_conjugate_posterior_params`` returns a ``(params, posterior)``
    pair, so only the parameter dictionary is unpacked and converted.
    """
    gpjax_params, _ = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # add the same prior to all parameters
    new_numpyro_params = add_priors(numpyro_params, prior)

    for iparams in new_numpyro_params.values():
        # check each entry is now a prior holding the supplied distribution
        chex.assert_equal(iparams["param_type"], "prior")
        chex.assert_equal(iparams["prior"], prior)

    # check we didn't modify the original GPJax parameter dictionary
    fresh_params, _ = _get_conjugate_posterior_params()
    chex.assert_equal(gpjax_params, fresh_params)
def test_numpyro_dict_priors_defaults_tfp():
    """TFP distributions are converted into prior entries.

    Each entry should contain exactly ``prior`` and ``param_type`` keys, and
    the stored prior must be the distribution that was passed in.
    """
    param_names = ("lengthscale", "variance", "obs_noise")
    demo_priors = {
        name: tfd.LogNormal(loc=0.0, scale=1.0) for name in param_names
    }

    numpyro_params = numpyro_dict_params(demo_priors)

    # one output entry per input parameter
    assert set(numpyro_params) == set(demo_priors.keys())

    for name, distribution in demo_priors.items():
        entry = numpyro_params[name]
        # a prior entry carries no init_value / constraint keys
        assert set(entry.keys()) == {"prior", "param_type"}
        # the stored prior is the distribution we supplied
        chex.assert_equal(entry["prior"], distribution)
def test_numpyro_add_priors_dict(variable, prior):
    """``add_priors`` with a ``{name: distribution}`` dict updates that parameter.

    ``_get_conjugate_posterior_params`` returns a ``(params, posterior)``
    pair, so only the parameter dictionary is unpacked and converted.
    """
    gpjax_params, _ = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)

    # create a {variable: prior} mapping
    new_param_dict = {str(variable): prior}

    # add prior for the named variable
    new_numpyro_params = add_priors(numpyro_params, new_param_dict)

    # check the variable is now a prior holding the supplied distribution
    chex.assert_equal(new_numpyro_params[variable]["param_type"], "prior")
    chex.assert_equal(new_numpyro_params[variable]["prior"], prior)

    # check we didn't modify the original GPJax parameter dictionary
    fresh_params, _ = _get_conjugate_posterior_params()
    chex.assert_equal(gpjax_params, fresh_params)
def test_numpyro_marginal_ll_params_shape(n_samples, n_features, n_latents):
    """One seeded forward pass returns an output shaped like the squeezed targets."""
    dataset = _gen_training_data(n_samples, n_features, n_latents)

    # GPJax params -> numpyro-style params -> numpyro marginal-likelihood model
    gpjax_params, posterior = _get_conjugate_posterior_params()
    model = numpyro_marginal_ll(posterior, numpyro_dict_params(gpjax_params))

    # the seed handler supplies the RNG key for the model's random sites
    with numpyro.handlers.seed(rng_seed=KEY):
        output = model(dataset)

    chex.assert_equal_shape([dataset.y.squeeze(), output])
def test_numpyro_marginal_ll_params():
    """Every converted parameter name appears as a site in the model trace."""
    dataset = _gen_training_data(10, 10, 2)

    gpjax_params, posterior = _get_conjugate_posterior_params()
    numpyro_params = numpyro_dict_params(gpjax_params)
    model = numpyro_marginal_ll(posterior, numpyro_params)

    # record the sites visited during one seeded execution
    with numpyro.handlers.seed(rng_seed=KEY):
        trace = numpyro.handlers.trace(model).get_trace(dataset)

    assert set(numpyro_params).issubset(trace)