def test_haiku_state_dropout_smoke(dropout, batchnorm):
    import haiku as hk

    def fn(x):
        if dropout:
            x = hk.dropout(hk.next_rng_key(), 0.5, x)
        if batchnorm:
            x = hk.BatchNorm(create_offset=True, create_scale=True,
                             decay_rate=0.001)(x, is_training=True)
        return x

    def model():
        transform = hk.transform_with_state if batchnorm else hk.transform
        nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3))
        x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
        if dropout:
            y = nn(numpyro.prng_key(), x)
        else:
            y = nn(x)
        numpyro.deterministic("y", y)

    with handlers.trace(model) as tr, handlers.seed(rng_seed=0):
        model()

    if batchnorm:
        assert set(tr.keys()) == {"nn$params", "nn$state", "x", "y"}
        assert tr["nn$state"]["type"] == "mutable"
    else:
        assert set(tr.keys()) == {"nn$params", "x", "y"}

def single_prediction(val):
    # closure over `model`, `model_args`, `model_kwargs`, and `return_sites`
    # from the enclosing predictive helper
    rng_key, samples = val
    model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
        *model_args, **model_kwargs)
    if return_sites is not None:
        if return_sites == '':
            sites = {
                k for k, site in model_trace.items() if site['type'] != 'plate'
            }
        else:
            sites = return_sites
    else:
        sites = {
            k for k, site in model_trace.items()
            if (site['type'] == 'sample' and k not in samples)
            or (site['type'] == 'deterministic')
        }
    return {
        name: site['value']
        for name, site in model_trace.items() if name in sites
    }

def log_likelihood(params_flat, subsample_indices=None):
    # closure over `model`, `model_args`, `model_kwargs`, `unravel_fn`,
    # `prototype_trace`, and `subsample_plate_sizes` from the enclosing scope
    if subsample_indices is None:
        subsample_indices = {
            k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()
        }
    params = unravel_fn(params_flat)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        params = {
            name: biject_to(prototype_trace[name]["fn"].support)(value)
            for name, value in params.items()
        }
        with block(), trace() as tr, substitute(data=subsample_indices), \
                substitute(data=params):
            model(*model_args, **model_kwargs)

    log_lik = {}
    for site in tr.values():
        if site["type"] == "sample" and site["is_observed"]:
            for frame in site["cond_indep_stack"]:
                if frame.name in log_lik:
                    log_lik[frame.name] += _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
                else:
                    log_lik[frame.name] = _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
    return log_lik

def constrain_fn(model, transforms, model_args, model_kwargs, params,
                 return_deterministic=False):
    """
    (EXPERIMENTAL INTERFACE) Gets value at each latent site in `model` given
    unconstrained parameters `params`. The `transforms` is used to transform these
    unconstrained parameters to base values of the corresponding priors in `model`.
    If a prior is a transformed distribution, the corresponding base value lies in
    the support of the base distribution. Otherwise, the base value lies in the
    support of the distribution.

    :param model: a callable containing NumPyro primitives.
    :param dict transforms: dictionary of transforms keyed by names. Names in
        `transforms` and `params` should align.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of unconstrained values keyed by site names.
    :param bool return_deterministic: whether to return the value of `deterministic`
        sites from the model. Defaults to `False`.
    :return: `dict` of transformed params.
    """
    params_constrained = transform_fn(transforms, params)
    substituted_model = substitute(model, base_param_map=params_constrained)
    model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
    return {
        k: v['value'] for k, v in model_trace.items()
        if (k in params) or (return_deterministic and v['type'] == 'deterministic')
    }

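# Hedged usage sketch for constrain_fn: the transforms dict is normally produced by
# get_model_transforms, but here it is built by hand with biject_to. The toy model
# and site name are illustrative only, and the module-level transform_fn /
# substitute / trace helpers that constrain_fn relies on are assumed to be in scope.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.transforms import biject_to


def _toy_model():
    numpyro.sample("x", dist.Gamma(2.0, 2.0))


transforms = {"x": biject_to(dist.Gamma(2.0, 2.0).support)}  # exp onto (0, inf)
constrained = constrain_fn(_toy_model, transforms, (), {}, {"x": jnp.array(-1.0)})
# constrained["x"] == jnp.exp(-1.0), which lies in the Gamma support
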
def get_model_transforms(rng_key, model, model_args=(), model_kwargs=None):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    seeded_model = seed(model, rng_key if rng_key.ndim == 1 else rng_key[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of dynamic constraints
    # or deterministic sites
    replay_model = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                replay_model = True
            else:
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                inv_transforms[k] = transform.parts[0]
                replay_model = True
            else:
                inv_transforms[k] = transform
        elif v['type'] == 'deterministic':
            replay_model = True
    return inv_transforms, replay_model

def constrain_fn(model, model_args, model_kwargs, transforms, params):
    """
    Gets value at each latent site in `model` given unconstrained parameters
    `params`. The `transforms` is used to transform these unconstrained parameters
    to base values of the corresponding priors in `model`. If a prior is a
    transformed distribution, the corresponding base value lies in the support of
    the base distribution. Otherwise, the base value lies in the support of the
    distribution.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict transforms: dictionary of transforms keyed by names. Names in
        `transforms` and `params` should align.
    :param dict params: dictionary of unconstrained values keyed by site names.
    :return: `dict` of transformed params.
    """
    params_constrained = transform_fn(transforms, params)
    substituted_model = substitute(model, base_param_map=params_constrained)
    model_trace = trace(substituted_model).get_trace(*model_args, **model_kwargs)
    return {
        k: model_trace[k]['value'] for k in params if k in model_trace
    }

def _get_model_transforms(model, model_args=(), model_kwargs=None):
    model_kwargs = {} if model_kwargs is None else model_kwargs
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of deterministic sites
    replay_model = False
    has_enumerate_support = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['fn'].is_discrete:
                has_enumerate_support = True
                if not v['fn'].has_enumerate_support:
                    raise RuntimeError(
                        "MCMC only supports continuous sites or discrete sites "
                        f"with enumerate support, but got {type(v['fn']).__name__}.")
            else:
                support = v['fn'].support
                inv_transforms[k] = biject_to(support)
                # XXX: the following code filters out most situations with
                # dynamic supports
                args = ()
                if isinstance(support, constraints._GreaterThan):
                    args = ('lower_bound',)
                elif isinstance(support, constraints._Interval):
                    args = ('lower_bound', 'upper_bound')
                for arg in args:
                    if not isinstance(getattr(support, arg), (int, float)):
                        replay_model = True
        elif v['type'] == 'deterministic':
            replay_model = True
    return inv_transforms, replay_model, has_enumerate_support, model_trace

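# Hedged sketch of what _get_model_transforms reports for a simple model: one
# biject_to bijector per continuous latent plus the replay / enumeration flags. The
# toy model is illustrative only; the model must already carry its own PRNG seed
# because the helper traces it directly.
import jax
import numpyro
import numpyro.distributions as dist
from numpyro import handlers


def _toy_model():
    numpyro.sample("scale", dist.HalfNormal(1.0))


seeded_model = handlers.seed(_toy_model, jax.random.PRNGKey(0))
inv_transforms, replay, has_enum, tr = _get_model_transforms(seeded_model)
# inv_transforms["scale"] maps the reals onto (0, inf); for this model both
# replay and has_enum come back False
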
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, param_map=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.array(0.)
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            scale = site['scale']
            if intermediates:
                log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)
            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob
            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace

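# Hedged usage sketch for log_density: a toy model evaluated at a fixed latent
# value. Because `params` pins every latent site, no seed handler is needed. The
# substitute / trace handlers used inside log_density (with the `param_map` keyword
# of the NumPyro version this snippet targets) are assumed to be in scope.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def _toy_model():
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=jnp.array(0.5))


logp, model_tr = log_density(_toy_model, (), {}, {"mu": jnp.array(0.0)})
# logp == Normal(0, 1).log_prob(0.0) + Normal(0, 1).log_prob(0.5)
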
def sample_prior_predictive(rng_key, model, model_args, substitutes=None,
                            with_intermediates=False, **kwargs):
    """
    Samples once from the prior predictive distribution.

    Individual sample sites, as designated by `sample`, can be frozen to
    pre-determined values given in `substitutes`. In that case, values for these
    sites are not actually sampled but the value provided in `substitutes` is
    returned as the sample. This facilitates conditional sampling.

    Note that if the model function is written in such a way that it returns, e.g.,
    multiple observations from a single prior draw, the same is true for the values
    returned by this function.

    :param rng_key: JAX PRNG key
    :param model: Function representing the model using numpyro distributions and
        the `sample` primitive
    :param model_args: Arguments to the model function
    :param substitutes: An optional dictionary of frozen substitutes for sample
        sites.
    :param with_intermediates: If True, intermediate(/latent) samples from sample
        site distributions are included in the result.
    :param kwargs: Keyword arguments passed to the model function.
    :return: Dictionary of sampled values associated with the names given via
        `sample()` in the model. If with_intermediates is True, dictionary values
        are tuples where the first element is the final sample values and the
        second element is a list of intermediate values.
    """
    if substitutes is None:
        substitutes = dict()
    model = seed(substitute(model, data=substitutes), rng_key)
    t = trace(model).get_trace(*model_args, **kwargs)
    return get_samples_from_trace(t, with_intermediates)

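# Hedged usage sketch: conditional prior sampling with one frozen site. The toy
# model is illustrative only, and get_samples_from_trace is assumed to collect the
# values at the model's sample sites as described in the docstring above.
import jax
import numpyro
import numpyro.distributions as dist


def _toy_model(n):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", n):
        numpyro.sample("x", dist.Normal(mu, 1.0))


prior_samples = sample_prior_predictive(jax.random.PRNGKey(0), _toy_model, (5,),
                                        substitutes={"mu": 0.0})
# expected: prior_samples["mu"] stays at the frozen value 0.0 and
# prior_samples["x"] has shape (5,)
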
def body_fn(state):
    i, key, _, _ = state
    key, subkey = random.split(key)

    if radius is None or prototype_params is None:
        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed'] \
                    and not v['fn'].is_discrete:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        params = transform_fn(inv_transforms,
                              {k: v for k, v in constrained_values.items()},
                              invert=True)
    else:  # this branch doesn't require tracing the model
        params = {}
        for k, v in prototype_params.items():
            if k in init_values:
                params[k] = init_values[k]
            else:
                params[k] = random.uniform(subkey, jnp.shape(v), minval=-radius,
                                           maxval=radius)
                key, subkey = random.split(key)

    potential_fn = partial(potential_energy, model, model_args, model_kwargs,
                           enum=enum)
    pe, z_grad = value_and_grad(potential_fn)(params)
    z_grad_flat = ravel_pytree(z_grad)[0]
    is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
    return i + 1, key, (params, pe, z_grad), is_valid

def test_counterfactual_query(intervene, observe, flip):
    # x -> y -> z -> w
    sites = ["x", "y", "z", "w"]
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        with handlers.seed(rng_seed=0):
            x = numpyro.sample("x", dist.Normal(0, 1))
            y = numpyro.sample("y", dist.Normal(x, 1))
            z = numpyro.sample("z", dist.Normal(y, 1))
            w = numpyro.sample("w", dist.Normal(z, 1))
        return dict(x=x, y=y, z=z, w=w)

    if not flip:
        if intervene:
            model = handlers.do(model, data=interventions)
        if observe:
            model = handlers.condition(model, data=observations)
    elif flip and intervene and observe:
        model = handlers.do(handlers.condition(model, data=observations),
                            data=interventions)

    with handlers.trace() as tr:
        actual_values = model()
    for name in sites:
        # case 1: purely observational query like handlers.condition
        if not intervene and observe:
            if observations[name] is not None:
                assert tr[name]['is_observed']
                assert_allclose(observations[name], actual_values[name])
                assert_allclose(observations[name], tr[name]['value'])
            if interventions[name] != observations[name]:
                if interventions[name] is not None:
                    assert_raises(AssertionError, assert_allclose,
                                  interventions[name], actual_values[name])
        # case 2: purely interventional query like old handlers.do
        elif intervene and not observe:
            assert not tr[name]['is_observed']
            if interventions[name] is not None:
                assert_allclose(interventions[name], actual_values[name])
            if observations[name] is not None:
                assert_raises(AssertionError, assert_allclose,
                              observations[name], tr[name]['value'])
            if interventions[name] is not None:
                assert_raises(AssertionError, assert_allclose,
                              interventions[name], tr[name]['value'])
        # case 3: counterfactual query mixing intervention and observation
        elif intervene and observe:
            if observations[name] is not None:
                assert tr[name]['is_observed']
                assert_allclose(observations[name], tr[name]['value'])
            if interventions[name] is not None:
                assert_allclose(interventions[name], actual_values[name])
            if interventions[name] != observations[name]:
                if interventions[name] is not None:
                    assert_raises(AssertionError, assert_allclose,
                                  interventions[name], tr[name]['value'])

def test_haiku_module():
    X = np.arange(100)
    Y = 2 * X + 2

    with handlers.trace() as haiku_tr, handlers.seed(rng_seed=1):
        haiku_model(X, Y)
    assert haiku_tr["nn$params"]['value']['linear']['w'].shape == (100, 100)
    assert haiku_tr["nn$params"]['value']['linear']['b'].shape == (100,)

def predict(model, rng_key, samples, X, D_H, sigma_obs=None):
    model = handlers.substitute(handlers.seed(model, rng_key), samples)
    # note that Y will be sampled in the model because we pass Y=None here
    model_trace = handlers.trace(model).get_trace(X=X, Y=None, D_H=D_H,
                                                  sigma_obs=sigma_obs)
    return model_trace['Y']['value']

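# Hedged, self-contained sketch of the usual pattern around this helper:
# vectorizing predict over posterior draws with jax.vmap. The toy model below only
# mirrors the expected signature (X, Y, D_H, sigma_obs), and the fake `samples`
# dict stands in for MCMC output with a leading sample dimension.
import jax.numpy as jnp
from jax import random, vmap
import numpyro
import numpyro.distributions as dist


def _toy_model(X, Y, D_H, sigma_obs=None):
    w = numpyro.sample("w", dist.Normal(jnp.zeros(X.shape[1]), 1.0).to_event(1))
    numpyro.sample("Y", dist.Normal(X @ w, 1.0).to_event(1), obs=Y)


samples = {"w": jnp.zeros((10, 3))}  # 10 fake posterior draws
keys = random.split(random.PRNGKey(1), 10)
X_test = jnp.ones((5, 3))
preds = vmap(lambda key, s: predict(_toy_model, key, s, X_test, 4))(keys, samples)
# preds.shape == (10, 5): one row of predictions per posterior draw
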
def test_load_numpyro_model_simple_working_model(self):
    """only verifies that no errors occur and all returned functions are not None"""
    model, guide, preprocess, postprocess = load_custom_numpyro_model(
        './tests/models/simple_gauss_model.py', Namespace(), [], pd.DataFrame())
    self.assertIsNotNone(model)
    self.assertIsNotNone(guide)
    self.assertIsNotNone(preprocess)
    self.assertIsNotNone(postprocess)

    z = np.ones((10, 2))
    samples_with_obs = trace(seed(model, jax.random.PRNGKey(0))).get_trace(
        z, num_obs_total=10)
    self.assertTrue(np.allclose(samples_with_obs['x']['value'], z))

    samples_no_obs = trace(seed(model, jax.random.PRNGKey(0))).get_trace(
        num_obs_total=10)
    self.assertEqual(samples_no_obs['x']['value'].shape, (1, 2))
    self.assertFalse(np.allclose(samples_no_obs['x']['value'], z))

def test_distribution_1(temperature):
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = np.array([1.0, 2.0, 3.0])

    @config_enumerate
    def model(z=None):
        p = numpyro.param("p", np.array([0.75, 0.25]))
        iz = numpyro.sample("z", dist.Categorical(p), obs=z)
        z = jnp.array([0.0, 1.0])[iz]
        logger.info("z.shape = {}".format(z.shape))
        with numpyro.plate("data", 3):
            numpyro.sample("x", dist.Normal(z, 1.0), obs=data)

    first_available_dim = -3
    vectorized_model = (
        model if temperature == 0 else vectorize_model(model, num_particles, dim=-2)
    )
    sampled_model = infer_discrete(
        vectorized_model, first_available_dim, temperature, rng_key=random.PRNGKey(1)
    )
    sampled_trace = handlers.trace(sampled_model).get_trace()
    conditioned_traces = {
        z: handlers.trace(model).get_trace(z=np.array(z)) for z in [0, 1]
    }

    # Check posterior over z.
    actual_z_mean = sampled_trace["z"]["value"].astype(float).mean()
    if temperature:
        expected_z_mean = 1 / (
            1
            + np.exp(
                log_prob_sum(conditioned_traces[0])
                - log_prob_sum(conditioned_traces[1])
            )
        )
    else:
        expected_z_mean = (
            log_prob_sum(conditioned_traces[1]) > log_prob_sum(conditioned_traces[0])
        ).astype(float)
        expected_max = max(log_prob_sum(t) for t in conditioned_traces.values())
        actual_max = log_prob_sum(sampled_trace)
        assert_allclose(expected_max, actual_max, atol=1e-5)
    assert_allclose(actual_z_mean, expected_z_mean,
                    atol=1e-2 if temperature else 1e-5)

def get_log_probs(sample, seed=0):
    # use the keyword form `rng_seed=seed` here: the handler only needs to seed the
    # block, since the model is called explicitly inside the context managers
    with handlers.trace() as tr, handlers.seed(rng_seed=seed), \
            handlers.substitute(data=sample):
        model(*model_args, **model_kwargs)
    return {
        name: site["fn"].log_prob(site["value"])
        for name, site in tr.items() if site["type"] == "sample"
    }

def log_density(model, model_args, model_kwargs, params):
    model = substitute(model, params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = 0.
    for site in model_trace.values():
        if site['type'] == 'sample':
            log_joint = log_joint + np.sum(site['fn'].log_prob(site['value']))
    return log_joint, model_trace

def test_flax_module():
    X = np.arange(100)
    Y = 2 * X + 2

    with handlers.trace() as flax_tr, handlers.seed(rng_seed=1):
        flax_model(X, Y)
    assert flax_tr["nn$params"]['value']['kernel'].shape == (100, 100)
    assert flax_tr["nn$params"]['value']['bias'].shape == (100,)

def log_likelihood(
    rng_key: np.ndarray,
    params: Dict[str, np.ndarray],
    model: Callable,
    *args: Any,
    **kwargs: Any,
) -> np.ndarray:
    model = handlers.condition(model, params)
    model_trace = handlers.trace(model).get_trace(*args, **kwargs)
    obs_node = model_trace["obs"]
    return obs_node["fn"].log_prob(obs_node["value"])

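# Hedged usage sketch for this log_likelihood helper, which assumes the model has a
# site literally named "obs". The toy model is illustrative only.
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def _toy_model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=y)


loglik = log_likelihood(jax.random.PRNGKey(0), {"mu": jnp.array(0.0)}, _toy_model,
                        jnp.array([0.5, -0.5]))
# loglik has shape (2,): the pointwise log density of each observation
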
def single_loglik(samples):
    model_trace = trace(substitute(model, samples)).get_trace(*args, **kwargs)
    return {
        name: site['fn'].log_prob(site['value'])
        for name, site in model_trace.items()
        if site['type'] == 'sample' and site['is_observed']
    }

def sample_posterior_predictive(rng_key, model, model_args, guide, guide_args,
                                params, with_intermediates=False, **kwargs):
    """
    Samples once from the posterior predictive distribution.

    Note that if the model function is written in such a way that it returns, e.g.,
    multiple observations from a single posterior draw, the same is true for the
    values returned by this function.

    :param rng_key: JAX PRNG key
    :param model: Function representing the model using numpyro distributions and
        the `sample` primitive
    :param model_args: Arguments to the model function
    :param guide: Function representing the variational distribution (the guide)
        using numpyro distributions as well as the `sample` and `param` primitives
    :param guide_args: Arguments to the guide function
    :param params: A dictionary providing values for the parameters designated by
        calls to `param` in the guide
    :param with_intermediates: If True, intermediate(/latent) samples from sample
        site distributions are included in the result.
    :param kwargs: Keyword arguments passed to the model and guide functions.
    :return: Dictionary of sampled values associated with the names given via
        `sample()` in the model. If with_intermediates is True, dictionary values
        are tuples where the first element is the final sample values and the
        second element is a list of intermediate values.
    """
    model_rng_key, guide_rng_key = jax.random.split(rng_key)

    guide = seed(substitute(guide, data=params), guide_rng_key)
    guide_samples = get_samples_from_trace(
        trace(guide).get_trace(*guide_args, **kwargs), with_intermediates
    )

    model_params = dict(**params)
    if with_intermediates:
        model_params.update({k: v[0] for k, v in guide_samples.items()})
    else:
        model_params.update(guide_samples)

    model = seed(substitute(model, data=model_params), model_rng_key)
    model_samples = get_samples_from_trace(
        trace(model).get_trace(*model_args, **kwargs), with_intermediates
    )

    guide_samples.update(model_samples)
    return guide_samples

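# Hedged usage sketch: a hand-written mean-field guide over one latent; in practice
# `params` would come from SVI. All names are illustrative, and as above
# get_samples_from_trace is assumed to collect the values at sample sites.
import jax
import numpyro
import numpyro.distributions as dist


def _toy_model(n):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", n):
        numpyro.sample("x", dist.Normal(mu, 1.0))


def _toy_guide(n):
    loc = numpyro.param("mu_loc", 0.0)
    scale = numpyro.param("mu_scale", 1.0)
    numpyro.sample("mu", dist.Normal(loc, scale))


post_samples = sample_posterior_predictive(
    jax.random.PRNGKey(0), _toy_model, (5,), _toy_guide, (5,),
    {"mu_loc": 0.0, "mu_scale": 0.1})
# post_samples contains "mu" drawn from the guide and "x" resampled from the model
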
def _get_log_probs(model, model_args, model_kwargs, sample):
    # Note: We use seed 0 for parameter initialization.
    with handlers.trace() as tr, handlers.seed(rng_seed=0), \
            handlers.substitute(data=sample):
        model(*model_args, **model_kwargs)
    return {
        name: site["fn"].log_prob(site["value"])
        for name, site in tr.items() if site["type"] == "sample"
    }

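# Hedged usage sketch for _get_log_probs: per-site log probabilities of one sample
# dict under a toy model (names illustrative only).
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def _toy_model():
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=jnp.array(0.5))


site_log_probs = _get_log_probs(_toy_model, (), {}, {"mu": jnp.array(0.0)})
# site_log_probs has one entry each for "mu" and "obs"
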
def test_scope_frames():
    def model(y):
        mu = numpyro.sample("mu", dist.Normal())
        sigma = numpyro.sample("sigma", dist.HalfNormal())
        with numpyro.plate("plate1", y.shape[0]):
            numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

    scope_prefix = "scope"
    scoped_model = handlers.scope(model, prefix=scope_prefix)
    obs = np.random.normal(size=(10,))

    trace = handlers.trace(handlers.seed(model, 0)).get_trace(obs)
    scoped_trace = handlers.trace(handlers.seed(scoped_model, 0)).get_trace(obs)

    assert trace["y"]["cond_indep_stack"][0].name in trace
    assert scoped_trace[f"{scope_prefix}/y"]["cond_indep_stack"][0].name \
        in scoped_trace

def predict(model, at_bats, hits, z, rng, player_names, train=True):
    header = model.__name__ + (' - TRAIN' if train else ' - TEST')
    model = substitute(seed(model, rng), z)
    model_trace = trace(model).get_trace(at_bats)
    predictions = model_trace['obs']['value']
    print_results('=' * 30 + header + '=' * 30,
                  predictions, player_names, at_bats, hits)
    if not train:
        model = substitute(model, z)
        model_trace = trace(model).get_trace(at_bats, hits)
        log_joint = 0.
        for site in model_trace.values():
            site_log_prob = site['fn'].log_prob(site['value'])
            log_joint = log_joint + np.sum(
                site_log_prob.reshape(site_log_prob.shape[:1] + (-1,)), -1)
        log_post_density = logsumexp(log_joint) - np.log(np.shape(log_joint)[0])
        print('\nPosterior log density: {:.2f}\n'.format(log_post_density))

def predict(model, rng_key, samples, p, t, D_H, data_type):
    model = handlers.substitute(handlers.seed(model, rng_key), samples)
    model_trace = handlers.trace(model).get_trace(p=p, t=t, Y=None, F=None,
                                                  D_H=D_H, data_type=data_type)
    return model_trace['Y']['value']

def single_prediction(rng, samples):
    model_trace = trace(seed(condition(model, samples), rng)).get_trace(
        *args, **kwargs)
    sites = (model_trace.keys() - samples.keys()
             if return_sites is None else return_sites)
    return {
        name: site['value']
        for name, site in model_trace.items() if name in sites
    }

def _setup_prototype(self, *args, **kwargs):
    # run the model so we can inspect its structure
    rng_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix),
                             dist.PRNGIdentity())
    model = handlers.seed(self.model, rng_key)
    self.prototype_trace = handlers.block(handlers.trace(model).get_trace)(
        *args, **kwargs)
    self._args = args
    self._kwargs = kwargs

def single_loglik(samples):
    substituted_model = (
        substitute(model, samples) if isinstance(samples, dict) else model
    )
    model_trace = trace(substituted_model).get_trace(*args, **kwargs)
    return {
        name: site["fn"].log_prob(site["value"])
        for name, site in model_trace.items()
        if site["type"] == "sample" and site["is_observed"]
    }

def initialize_model(rng, model, model_args, model_kwargs):
    model = seed(model, rng)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    init_params = {
        k: v['value'] for k, v in model_trace.items()
        if v['type'] == 'sample' and not v['is_observed']
    }
    return init_params, potential_energy(model, model_args, model_kwargs)

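# Hedged usage sketch: initialize_model above takes its initial latent values
# straight from a prior draw. It calls a `potential_energy` helper that is assumed
# to be defined elsewhere in this module; the toy model is illustrative only.
import jax
import numpyro
import numpyro.distributions as dist


def _toy_model():
    numpyro.sample("mu", dist.Normal(0.0, 1.0))


init_params, pe = initialize_model(jax.random.PRNGKey(0), _toy_model, (), {})
# init_params == {"mu": <a draw from the prior>}
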
def forecast(future, rng_key, sample, y, n_obs):
    Z_exp, obs_last = handlers.substitute(ar_k, sample)(n_obs, y)
    forecast_model = handlers.seed(_forecast, rng_key)
    forecast_trace = handlers.trace(forecast_model).get_trace(
        future, sample, Z_exp, n_obs)
    results = [
        np.clip(forecast_trace["yf[{}]".format(t)]["value"], a_min=1e-30)
        for t in range(future)
    ]
    return np.stack(results, axis=0)
