def predict(model, at_bats, hits, z, rng, player_names, train=True):
    header = model.__name__ + (' - TRAIN' if train else ' - TEST')
    model = substitute(seed(model, rng), z)
    model_trace = trace(model).get_trace(at_bats)
    predictions = model_trace['obs']['value']
    print_results('=' * 30 + header + '=' * 30,
                  predictions, player_names, at_bats, hits)
    if not train:
        model = substitute(model, z)
        model_trace = trace(model).get_trace(at_bats, hits)
        log_joint = 0.
        for site in model_trace.values():
            site_log_prob = site['fn'].log_prob(site['value'])
            # sum log probs over all but the leading (sample) dimension
            log_joint = log_joint + np.sum(
                site_log_prob.reshape(site_log_prob.shape[:1] + (-1,)), -1)
        # Monte Carlo estimate of the posterior log density:
        # logsumexp over samples minus log of the number of samples.
        log_post_density = logsumexp(log_joint) - np.log(np.shape(log_joint)[0])
        print('\nPosterior log density: {:.2f}\n'.format(log_post_density))
def test_substitute():
    def model():
        x = numpyro.param('x', None)
        y = handlers.substitute(
            lambda: numpyro.param('y', None) * numpyro.param('x', None),
            {'y': x})()
        return x + y

    assert handlers.substitute(model, {'x': 3.})() == 12.
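# Illustrative sketch (not from the sources above) of what `substitute`
# does: values supplied via `data` replace the values that `sample` or
# `param` sites would otherwise produce. Names here are hypothetical.
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro import handlers


def toy_model():
    x = numpyro.sample('x', dist.Normal(0., 1.))
    return 2. * x


# Without `substitute`, 'x' is drawn from its prior (hence the `seed`);
# with it, the site 'x' is pinned to the supplied value.
pinned = handlers.substitute(handlers.seed(toy_model, random.PRNGKey(0)),
                             {'x': 3.})
assert pinned() == 6.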
def unpack_single_latent(latent):
    unpacked_samples = self._unpack_latent(latent)
    if self._has_transformed_dist:
        # first, substitute values into `param` statements in the model
        model = handlers.substitute(self.model, params)
        return constrain_fn(model, self._inv_transforms,
                            model_args, model_kwargs, unpacked_samples)
    else:
        return transform_fn(self._inv_transforms, unpacked_samples)
def single_loglik(samples):
    substituted_model = (
        substitute(model, samples) if isinstance(samples, dict) else model
    )
    model_trace = trace(substituted_model).get_trace(*args, **kwargs)
    return {
        name: site["fn"].log_prob(site["value"])
        for name, site in model_trace.items()
        if site["type"] == "sample" and site["is_observed"]
    }
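# The closure above is essentially the inner worker of numpyro's public
# `log_likelihood` utility; a hedged end-to-end sketch with toy names:
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import log_likelihood


def toy_model(y=None):
    mu = numpyro.sample('mu', dist.Normal(0., 10.))
    numpyro.sample('y', dist.Normal(mu, 1.), obs=y)


y_obs = random.normal(random.PRNGKey(0), (10,))
mcmc = MCMC(NUTS(toy_model), num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(1), y=y_obs)
# one log-prob row per posterior sample, per observation
ll = log_likelihood(toy_model, mcmc.get_samples(), y=y_obs)
assert ll['y'].shape == (100, 10)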
def get_log_probs(sample):
    # Note: We use seed 0 for parameter initialization.
    with handlers.trace() as tr, handlers.seed(rng_seed=0), \
            handlers.substitute(data=sample):
        model(*model_args, **model_kwargs)
    return {
        name: site["fn"].log_prob(site["value"])
        for name, site in tr.items()
        if site["type"] == "sample"
    }
def predict(model, rng_key, samples, p, t, D_H, data_type):
    model = handlers.substitute(handlers.seed(model, rng_key), samples)
    model_trace = handlers.trace(model).get_trace(
        p=p, t=t, Y=None, F=None, D_H=D_H, data_type=data_type)
    return model_trace['Y']['value']
def log_prior(params):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        # replace each subsample plate with an empty index array so only
        # the prior terms contribute
        dummy_subsample = {
            k: jnp.array([], dtype=jnp.int32) for k in subsample_plate_sizes
        }
        with block(), substitute(data=dummy_subsample):
            prior_prob, _ = log_density(model, model_args, model_kwargs, params)
    return prior_prob
def init_fn(rng, model_args=(), guide_args=(), params=None):
    """
    :param jax.random.PRNGKey rng: random number generator seed.
    :param tuple model_args: arguments to the model (these can possibly vary
        during the course of fitting).
    :param tuple guide_args: arguments to the guide (these can possibly vary
        during the course of fitting).
    :param dict params: initial parameter values to condition on. This can be
        useful for initializing neural networks using more specialized methods
        rather than sampling from the prior.
    :return: tuple containing initial optimizer state, and `constrain_fn`, a
        callable that transforms unconstrained parameter values from the
        optimizer to the specified constrained domain.
    """
    assert isinstance(model_args, tuple)
    assert isinstance(guide_args, tuple)
    model_init, guide_init = _seed(model, guide, rng)
    if params is None:
        params = {}
    else:
        model_init = substitute(model_init, params)
        guide_init = substitute(guide_init, params)
    guide_trace = trace(guide_init).get_trace(*guide_args, **kwargs)
    model_trace = trace(model_init).get_trace(*model_args, **kwargs)
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] == 'param':
            constraint = site['kwargs'].pop('constraint', constraints.real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                base_transform = transform.parts[0]
                inv_transforms[site['name']] = base_transform
                params[site['name']] = base_transform(transform.inv(site['value']))
            else:
                inv_transforms[site['name']] = transform
                params[site['name']] = site['value']
    nonlocal constrain_fn
    constrain_fn = jax.partial(transform_fn, inv_transforms)
    return optim_init(params), constrain_fn
def median(self, params):
    """
    Returns the posterior median value of each latent variable.

    :param dict params: A dict containing parameter values.
    :return: A dict mapping sample site name to median tensor.
    :rtype: dict
    """
    loc, _ = handlers.substitute(self._loc_scale, params)()
    return self._unpack_and_constrain(loc, params)
def get_transform(self, params):
    """
    Returns the transformation learned by the guide to generate samples
    from the unconstrained (approximate) posterior.

    :param dict params: Current parameters of model and autoguide.
    :return: the transform of posterior distribution
    :rtype: :class:`~numpyro.distributions.constraints.Transform`
    """
    return handlers.substitute(self._get_transform, params)()
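# Hedged usage sketch for autoguide helpers like `median` above, using
# the current public numpyro API; model and hyperparameters are toy.
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal


def toy_model():
    numpyro.sample('mu', dist.Normal(0., 1.))


guide = AutoNormal(toy_model)
svi = SVI(toy_model, guide, numpyro.optim.Adam(0.1), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 200)
# `median` substitutes the optimized params into the guide internals,
# in the same way as the method above.
print(guide.median(svi_result.params))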
def forecast(future, rng_key, sample, y, n_obs):
    Z_exp, obs_last = handlers.substitute(ar_k, sample)(n_obs, y)
    forecast_model = handlers.seed(_forecast, rng_key)
    forecast_trace = handlers.trace(forecast_model).get_trace(
        future, sample, Z_exp, n_obs)
    results = [
        np.clip(forecast_trace["yf[{}]".format(t)]["value"], a_min=1e-30)
        for t in range(future)
    ]
    return np.stack(results, axis=0)
def get_log_probs(sample, seed=0):
    with handlers.trace() as tr, handlers.seed(model, seed), \
            handlers.substitute(data=sample):
        model(*model_args, **model_kwargs)
    return {
        name: site["fn"].log_prob(site["value"])
        for name, site in tr.items()
        if site["type"] == "sample"
    }
def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False):
    """
    Computes log of joint density for the model given latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site name.
    :param bool skip_dist_transforms: whether to compute log probability of a site
        (if its prior is a transformed distribution) in its base distribution domain.
    :return: log of joint density and a corresponding model trace
    """
    # We skip transforms in
    #   + autoguide's model
    #   + hmc's model
    # We apply transforms in
    #   + autoguide's guide
    #   + svi's model + guide
    if skip_dist_transforms:
        model = substitute(model, base_param_map=params)
    else:
        model = substitute(model, param_map=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = 0.
    for site in model_trace.values():
        if site['type'] == 'sample':
            value = site['value']
            intermediates = site['intermediates']
            if intermediates:
                if skip_dist_transforms:
                    log_prob = site['fn'].base_dist.log_prob(intermediates[0][0])
                else:
                    log_prob = site['fn'].log_prob(value, intermediates)
            else:
                log_prob = site['fn'].log_prob(value)
            log_prob = np.sum(log_prob)
            if 'scale' in site:
                log_prob = site['scale'] * log_prob
            log_joint = log_joint + log_prob
    return log_joint, model_trace
def log_density(model, model_args, model_kwargs, params):
    model = substitute(model, params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = 0.
    for site in model_trace.values():
        if site['type'] == 'sample':
            log_prob = np.sum(site['fn'].log_prob(site['value']))
            if 'scale' in site:
                log_prob = site['scale'] * log_prob
            log_joint = log_joint + log_prob
    return log_joint, model_trace
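# Hypothetical driver showing the same computation with the helper that
# current numpyro ships (`numpyro.infer.util.log_density`, same
# signature as the snippet above); model and values are toy. Because
# `substitute` pins the latent 'mu', no rng seed is needed to trace.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import log_density


def toy_model(data):
    mu = numpyro.sample('mu', dist.Normal(0., 1.))
    numpyro.sample('obs', dist.Normal(mu, 1.), obs=data)


logp, model_trace = log_density(toy_model, (jnp.array([0.5, -0.5]),), {},
                                {'mu': jnp.array(0.)})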
def _bnn_predict(self, rng_key, samples, X):
    '''
    This module takes the samples of a "trained" BNN and produces
    predictions based on the X values passed in.
    '''
    # Pass posterior-sampled parameters to the model
    model = handlers.substitute(handlers.seed(self._model, rng_key), samples)
    # Gather a trace over possible Y values given the model parameters and
    # input value X. Note: Y will be sampled in the model because we pass
    # Y=None here.
    model_trace = handlers.trace(model).get_trace(
        X=X, Y=None, D_H=self.bnn_dh, train=False)
    return model_trace['Y']['value']
def test_subsample_substitute():
    data = jnp.arange(100.)
    subsample_size = 7
    subsample = jnp.array([13, 3, 30, 4, 1, 68, 5])
    with handlers.trace() as tr, handlers.seed(rng_seed=0), \
            handlers.substitute(data={"a": subsample}):
        with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx:
            assert data[idx].shape == (subsample_size,)
            assert_allclose(idx, subsample)
    assert tr["a"]["kwargs"]["rng_key"] is None
def _blr_predict(self, rng_key, samples, X, predict=False):
    '''
    This module takes the samples of a "trained" BLR and produces
    predictions based on the provided X.
    '''
    # Pass posterior-sampled parameters to the model
    model = handlers.substitute(handlers.seed(self._model, rng_key), samples)
    # Gather a trace over possible Y values given the model parameters and
    # input value X
    model_trace = handlers.trace(model).get_trace(X=X, Y=None, predict=predict)
    return model_trace['obs']['value']
def __init__(self, rng, model, get_params_fn, prefix="auto",
             init_loc_fn=init_to_median):
    # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
    # Use `block` to not record sample primitives in `init_loc_fn`.
    model = substitute(model, substitute_fn=block(seed(init_loc_fn, rng)))
    super(AutoContinuous, self).__init__(model, get_params_fn, prefix=prefix)
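# The `substitute_fn` form used above differs from passing a `data`
# dict: instead of fixed values, a callable receives each site dict and
# returns the value to substitute. A minimal sketch with toy names:
import numpyro
import numpyro.distributions as dist
from numpyro import handlers


def toy_model():
    return numpyro.sample('x', dist.Normal(2., 1.))


def init_to_prior_mean(site):
    # Replace every sample site with its prior mean.
    return site['fn'].mean


assert handlers.substitute(toy_model, substitute_fn=init_to_prior_mean)() == 2.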
def log_likelihood(params_flat, subsample_indices=None):
    if subsample_indices is None:
        subsample_indices = {
            k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()
        }
    params = unravel_fn(params_flat)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        params = {
            name: biject_to(prototype_trace[name]["fn"].support)(value)
            for name, value in params.items()
        }
        with block(), trace() as tr, substitute(data=subsample_indices), \
                substitute(data=params):
            model(*model_args, **model_kwargs)

    log_lik = {}
    for site in tr.values():
        if site["type"] == "sample" and site["is_observed"]:
            for frame in site["cond_indep_stack"]:
                if frame.name in log_lik:
                    log_lik[frame.name] += _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
                else:
                    log_lik[frame.name] = _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
    return log_lik
def body_fn(state):
    i, key, _, _ = state
    key, subkey = random.split(key)

    if radius is None or prototype_params is None:
        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if (v["type"] == "sample"
                    and not v["is_observed"]
                    and not v["fn"].is_discrete):
                constrained_values[k] = v["value"]
                inv_transforms[k] = biject_to(v["fn"].support)
        params = transform_fn(
            inv_transforms,
            {k: v for k, v in constrained_values.items()},
            invert=True,
        )
    else:
        # this branch doesn't require tracing the model
        params = {}
        for k, v in prototype_params.items():
            if k in init_values:
                params[k] = init_values[k]
            else:
                params[k] = random.uniform(
                    subkey, jnp.shape(v), minval=-radius, maxval=radius)
                key, subkey = random.split(key)

    potential_fn = partial(
        potential_energy, model, model_args, model_kwargs, enum=enum)
    if validate_grad:
        if forward_mode_differentiation:
            pe = potential_fn(params)
            z_grad = jacfwd(potential_fn)(params)
        else:
            pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
    else:
        pe = potential_fn(params)
        is_valid = jnp.isfinite(pe)
        z_grad = None

    return i + 1, key, (params, pe, z_grad), is_valid
def single_prediction(val):
    rng_key, samples = val
    model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
        *model_args, **model_kwargs)
    if return_sites is not None:
        if return_sites == '':
            sites = {k for k, site in model_trace.items()
                     if site['type'] != 'plate'}
        else:
            sites = return_sites
    else:
        sites = {k for k, site in model_trace.items()
                 if (site['type'] == 'sample' and k not in samples)
                 or (site['type'] == 'deterministic')}
    return {name: site['value'] for name, site in model_trace.items()
            if name in sites}
def log_density(model, model_args, model_kwargs, params):
    def logp(d, val):
        with validation_disabled():
            return (d.logpdf(val) if isinstance(d.dist, jax_continuous)
                    else d.logpmf(val))

    model = substitute(model, params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = 0.
    for site in model_trace.values():
        if site['type'] == 'sample':
            log_joint = log_joint + np.sum(logp(site['fn'], site['value']))
    return log_joint, model_trace
def __call__(self, rng_key, *args, **kwargs):
    """
    Returns dict of samples from the predictive distribution. By default,
    only sample sites not contained in `posterior_samples` are returned.
    This can be modified by changing the `return_sites` keyword argument
    of this :class:`Predictive` instance.

    :param jax.random.PRNGKey rng_key: random key to draw samples.
    :param args: model arguments.
    :param kwargs: model kwargs.
    """
    posterior_samples = self.posterior_samples
    if self.guide is not None:
        rng_key, guide_rng_key = random.split(rng_key)
        # use return_sites='' as a special signal to return all sites
        guide = substitute(self.guide, self.params)
        posterior_samples = _predictive(
            guide_rng_key, guide, posterior_samples, self._batch_shape,
            return_sites='', parallel=self.parallel,
            model_args=args, model_kwargs=kwargs)
    model = substitute(self.model, self.params)
    return _predictive(
        rng_key, model, posterior_samples, self._batch_shape,
        return_sites=self.return_sites, parallel=self.parallel,
        model_args=args, model_kwargs=kwargs)
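# Hedged end-to-end sketch of driving `Predictive.__call__` above via
# the public numpyro API; the model and sizes are illustrative.
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive


def toy_model(y=None, n=20):
    mu = numpyro.sample('mu', dist.Normal(0., 10.))
    with numpyro.plate('data', n):
        numpyro.sample('y', dist.Normal(mu, 1.), obs=y)


y_obs = 3. + random.normal(random.PRNGKey(0), (20,))
mcmc = MCMC(NUTS(toy_model), num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(1), y=y_obs)
# 'mu' comes from `posterior_samples`, so only 'y' is re-sampled:
pred = Predictive(toy_model, mcmc.get_samples())(random.PRNGKey(2))
assert pred['y'].shape == (100, 20)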
def sample_posterior(self, rng_key, params, sample_shape=()):
    """
    Get samples from the learned posterior.

    :param jax.random.PRNGKey rng_key: random key to be used to draw samples.
    :param dict params: Current parameters of model and autoguide.
    :param tuple sample_shape: batch shape of each latent sample, defaults to ().
    :return: a dict containing samples drawn from this guide.
    :rtype: dict
    """
    latent_sample = handlers.substitute(
        handlers.seed(self._sample_latent, rng_key), params
    )(self.base_dist, sample_shape=sample_shape)
    return self._unpack_and_constrain(latent_sample, params)
def sample(self):
    rng_key, rng_key_sample, rng_key_accept = split(self.nmc_status.rng_key, 3)
    params = self.nmc_status.params
    for site in params.keys():
        # Collect the accepted trace for this site
        for i in range(len(params[site])):
            self.acc_trace[site + str(i)].append(params[site][i])

        tr_current = trace(substitute(self.model, params)).get_trace(
            *self.model_args, **self.model_kwargs)
        ll_current = self.nmc_status.log_likelihood
        val_current = tr_current[site]["value"]
        dist_curr = tr_current[site]["fn"]

        def log_den_fun(var):
            return partial(log_density, self.model, self.model_args,
                           self.model_kwargs)(var)[0]

        # Propose a new value for this site and evaluate the joint density
        # of the proposed trace
        val_proposal, dist_proposal = self.proposal(
            site, log_den_fun, self.get_params(tr_current), dist_curr,
            rng_key_sample)
        tr_proposal = self.retrace(site, tr_current, dist_proposal,
                                   val_proposal, self.model_args,
                                   self.model_kwargs)
        ll_proposal = log_density(self.model, self.model_args,
                                  self.model_kwargs,
                                  self.get_params(tr_proposal))[0]
        ll_proposal_val = dist_proposal.log_prob(val_current).sum()
        ll_current_val = dist_curr.log_prob(val_proposal).sum()

        # Metropolis-Hastings accept/reject step
        hastings_ratio = (ll_proposal + ll_proposal_val) - \
            (ll_current + ll_current_val)
        accept_prob = np.minimum(1, np.exp(hastings_ratio))
        u = sample("u", dist.Uniform(0, 1), rng_key=rng_key_accept)
        if u <= accept_prob:
            params, ll_current = self.get_params(tr_proposal), ll_proposal
        else:
            params, ll_current = self.get_params(tr_current), ll_current

    iter = self.nmc_status.i + 1
    mean_accept_prob = self.nmc_status.accept_prob + \
        (accept_prob - self.nmc_status.accept_prob) / iter
    return NMC_STATUS(iter, params, ll_current, mean_accept_prob, rng_key)
def run_model_on_samples_and_data(modelfn, samples, data):
    assert type(samples) == dict
    assert len(samples) > 0
    num_chains, num_samples = next(iter(samples.values())).shape[0:2]
    assert all(arr.shape[0:2] == (num_chains, num_samples)
               for arr in samples.values())
    flat_samples = {k: flatten(arr) for k, arr in samples.items()}
    out = vmap(lambda sample: handler.substitute(modelfn, sample)(
        **data, mode='prior_and_mu'))(flat_samples)
    # Restore chain dim.
    return {k: unflatten(arr, num_chains, num_samples) for k, arr in out.items()}
def model(batch, subsample, full_size):
    drift = numpyro.sample("drift", dist.LogNormal(-1, 0.5))
    with handlers.substitute(data={"data": subsample}):
        plate = numpyro.plate("data", full_size, subsample_size=len(subsample))
    assert plate.size == 50

    def transition_fn(z_prev, y_curr):
        with plate:
            z_curr = numpyro.sample("state", dist.Normal(z_prev, drift))
            y_curr = numpyro.sample("obs", dist.Bernoulli(logits=z_curr),
                                    obs=y_curr)
        return z_curr, y_curr

    _, result = scan(transition_fn, jnp.zeros(len(subsample)), batch,
                     length=num_time_steps)
    return result
def log_likelihood(params, subsample_indices=None):
    params_flat, unravel_fn = ravel_pytree(params)
    if subsample_indices is None:
        subsample_indices = {
            k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()
        }
    params = unravel_fn(params_flat)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with block(), trace() as tr, substitute(data=subsample_indices), \
                substitute(substitute_fn=partial(_unconstrain_reparam, params)):
            model(*model_args, **model_kwargs)

    log_lik = defaultdict(float)
    for site in tr.values():
        if site["type"] == "sample" and site["is_observed"]:
            for frame in site["cond_indep_stack"]:
                if frame.name in subsample_plate_sizes:
                    log_lik[frame.name] += _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
    return log_lik
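# The per-block likelihood estimators above back numpyro's HMCECS
# kernel; a hedged sketch of the public entry point. The model, sizes,
# and Taylor-proxy reference point are all toy choices.
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMCECS, MCMC, NUTS


def toy_model(data):
    mu = numpyro.sample('mu', dist.Normal(0., 1.))
    # subsample plate: only 100 of the 1000 points enter each step
    with numpyro.plate('data', data.shape[0], subsample_size=100):
        batch = numpyro.subsample(data, event_dim=0)
        numpyro.sample('obs', dist.Normal(mu, 1.), obs=batch)


data = random.normal(random.PRNGKey(0), (1000,))
proxy = HMCECS.taylor_proxy({'mu': jnp.array(0.)})
kernel = HMCECS(NUTS(toy_model), num_blocks=10, proxy=proxy)
mcmc = MCMC(kernel, num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(1), data)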