def get_model_transforms(rng_key, model, model_args=(), model_kwargs=None):
    """Trace ``model`` once and collect, per latent site, the bijection from
    the unconstrained space to the site's support.

    :return: tuple ``(inv_transforms, replay_model)`` where ``replay_model``
        is True when the model must be replayed (transformed distributions,
        constrained params, or deterministic sites are present).
    """
    if model_kwargs is None:
        model_kwargs = {}
    # a batch of keys may be passed in; a single key suffices for tracing
    seed_key = rng_key if rng_key.ndim == 1 else rng_key[0]
    seeded_model = seed(model, seed_key)
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of dynamic constraints
    # or deterministic sites
    replay_model = False
    for name, site in model_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            if site['intermediates']:
                # transformed distribution: work in the base dist's support
                inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                replay_model = True
            else:
                inv_transforms[name] = biject_to(site['fn'].support)
        elif site['type'] == 'param':
            constraint = site['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                inv_transforms[name] = transform.parts[0]
                replay_model = True
            else:
                inv_transforms[name] = transform
        elif site['type'] == 'deterministic':
            replay_model = True
    return inv_transforms, replay_model
def _setup_prototype(self, *args, **kwargs):
    """Run the base setup, find valid initial latent values, and record the
    inverse transforms for each latent site.

    Populates ``self._inv_transforms``, ``self._has_transformed_dist``,
    ``self._init_latent``, ``self._unpack_latent``, ``self.latent_size`` and
    (if unset) ``self.base_dist``.

    :raises RuntimeError: if the model exposes no latent variables.
    """
    super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
    rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix), dist.PRNGIdentity())
    # `block` hides the init search's sample statements from enclosing handlers
    init_params, _ = handlers.block(find_valid_initial_params)(
        rng_key, self.model,
        init_strategy=self.init_strategy,
        model_args=args,
        model_kwargs=kwargs)
    self._inv_transforms = {}
    self._has_transformed_dist = False
    # BUGFIX(cleanup): the original also built an `unconstrained_sites` dict
    # (including `transform.inv` calls) that was never read; removed.
    for name, site in self.prototype_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            if site['intermediates']:
                self._inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                self._has_transformed_dist = True
            else:
                self._inv_transforms[name] = biject_to(site['fn'].support)
    self._init_latent, self._unpack_latent = ravel_pytree(init_params)
    self.latent_size = np.size(self._init_latent)
    if self.base_dist is None:
        self.base_dist = dist.Independent(dist.Normal(np.zeros(self.latent_size), 1.), 1)
    if self.latent_size == 0:
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'.format(type(self).__name__))
def body_fn(state):
    """Single attempt at finding valid initial parameters.

    Draws constrained values from the model via ``init_strategy``, maps them
    to the unconstrained space, and reports whether the potential energy and
    its gradient are both finite.
    """
    i, key, _, _ = state
    key, subkey = random.split(key)
    # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
    # Use `block` to not record sample primitives in `init_loc_fn`.
    seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    for name, site in model_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            if site['intermediates']:
                constrained_values[name] = site['intermediates'][0][0]
                inv_transforms[name] = biject_to(site['fn'].base_dist.support)
            else:
                constrained_values[name] = site['value']
                inv_transforms[name] = biject_to(site['fn'].support)
        elif site['type'] == 'param' and param_as_improper:
            constraint = site['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                base_transform = transform.parts[0]
                inv_transforms[name] = base_transform
                constrained_values[name] = base_transform(transform.inv(site['value']))
            else:
                inv_transforms[name] = transform
                constrained_values[name] = site['value']
    # IDIOM FIX: `dict(...)` replaces the manual `{k: v for k, v in ...}` copy
    params = transform_fn(inv_transforms, dict(constrained_values), invert=True)
    potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
    pe, param_grads = value_and_grad(potential_fn)(params)
    z_grad = ravel_pytree(param_grads)[0]
    is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
    return i + 1, key, params, is_valid
def test_model_with_transformed_distribution():
    """Check that reparameterizing a transformed-distribution site gives the
    same constrained samples and potential energy as a direct computation."""
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    params = {'x': jnp.array(-5.), 'y': jnp.array(7.)}
    seeded_model = handlers.seed(model, random.PRNGKey(0))
    inv_transforms = {'x': biject_to(x_prior.support), 'y': biject_to(y_prior.support)}
    expected_samples = partial(transform_fn, inv_transforms)(params)
    # negative log joint plus log-abs-det corrections of both bijections
    expected_potential_energy = (
        -x_prior.log_prob(expected_samples['x'])
        - y_prior.log_prob(expected_samples['y'])
        - inv_transforms['x'].log_abs_det_jacobian(params['x'], expected_samples['x'])
        - inv_transforms['y'].log_abs_det_jacobian(params['y'], expected_samples['y'])
    )

    reparam_model = handlers.reparam(seeded_model, {'y': TransformReparam()})
    base_params = {'x': params['x'], 'y_base': params['y']}
    actual_samples = constrain_fn(
        handlers.seed(reparam_model, random.PRNGKey(0)), (), {}, base_params,
        return_deterministic=True)
    actual_potential_energy = potential_energy(reparam_model, (), {}, base_params)

    assert_allclose(expected_samples['x'], actual_samples['x'])
    assert_allclose(expected_samples['y'], actual_samples['y'])
    assert_allclose(actual_potential_energy, expected_potential_energy)
def get_potential_fn(rng_key, model, dynamic_args=False, model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a
    function which, given unconstrained parameters, evaluates the potential
    energy (negative log joint density). In addition, this returns a function
    to transform unconstrained values at sample sites to constrained values
    within their respective support.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`potential_fn`, `constrain_fn`). The latter is used
        to constrain unconstrained samples (e.g. those returned by HMC)
        to values that lie within the site's support.
    """
    if model_kwargs is None:
        model_kwargs = {}
    # a batch of keys may be passed in; a single key suffices for tracing
    seed_key = rng_key if rng_key.ndim == 1 else rng_key[0]
    model_trace = trace(seed(model, seed_key)).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    has_transformed_dist = False
    for name, site in model_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            if site['intermediates']:
                inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                inv_transforms[name] = biject_to(site['fn'].support)
        elif site['type'] == 'param':
            constraint = site['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                inv_transforms[name] = transform.parts[0]
                has_transformed_dist = True
            else:
                inv_transforms[name] = transform

    if dynamic_args:
        # factories: callers pass model args/kwargs and get back the callables
        def potential_fn(*args, **kwargs):
            return jax.partial(potential_energy, model, inv_transforms, args, kwargs)

        if has_transformed_dist:
            def constrain_fun(*args, **kwargs):
                return jax.partial(constrain_fn, model, inv_transforms, args, kwargs)
        else:
            def constrain_fun(*args, **kwargs):
                return jax.partial(transform_fn, inv_transforms)
    else:
        potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        if has_transformed_dist:
            constrain_fun = jax.partial(constrain_fn, model, inv_transforms, model_args, model_kwargs)
        else:
            constrain_fun = jax.partial(transform_fn, inv_transforms)

    return potential_fn, constrain_fun
def _init_to_uniform(site, radius=2, skip_param=False):
    """Initialize a site by drawing uniformly from ``(-radius, radius)`` in
    the unconstrained space, then mapping back to the support."""
    if site['type'] == 'sample' and not site['is_observed']:
        # work with the base distribution of a transformed distribution
        fn = site['fn'].base_dist if isinstance(site['fn'], dist.TransformedDistribution) else site['fn']
        value = numpyro.sample('_init', fn, sample_shape=site['kwargs']['sample_shape'])
        base_transform = biject_to(fn.support)
        unconstrained_value = numpyro.sample(
            '_unconstrained_init',
            dist.Uniform(-radius, radius),
            sample_shape=np.shape(base_transform.inv(value)))
        return base_transform(unconstrained_value)
    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        unconstrained_value = numpyro.sample(
            '_unconstrained_init',
            dist.Uniform(-radius, radius),
            sample_shape=np.shape(transform.inv(value)))
        base_transform = transform.parts[0] if isinstance(transform, ComposeTransform) else transform
        return base_transform(unconstrained_value)
def substitute_fn(site):
    """Substitute a traced site from ``params``: sample sites are given the
    constrained value ``biject_to(support)(params[name])``, other site types
    get the raw stored value."""
    name = site["name"]
    if name not in params:
        return None
    if site["type"] == "sample":
        with helpful_support_errors(site):
            return biject_to(site["fn"].support)(params[name])
    return params[name]
def _unconstrain_reparam(params, site):
    """Reparameterize a sample site to the unconstrained space.

    Returns the constrained value for the site and emits a ``factor`` carrying
    the log-abs-det Jacobian so the joint density is preserved.
    """
    name = site["name"]
    if name not in params:
        return None
    p = params[name]
    support = site["fn"].support
    t = biject_to(support)
    # in scan, we might only want to substitute an item at index i, rather than the whole sequence
    i = site["infer"].get("_scan_current_index", None)
    if i is not None:
        event_dim_shift = t.codomain.event_dim - t.domain.event_dim
        expected_unconstrained_dim = len(site["fn"].shape()) - event_dim_shift
        # check if p has additional time dimension
        if jnp.ndim(p) > expected_unconstrained_dim:
            p = p[i]
    if support in [constraints.real, constraints.real_vector]:
        # identity transform: no Jacobian correction needed
        return p
    value = t(p)
    log_det = t.log_abs_det_jacobian(p, value)
    log_det = sum_rightmost(
        log_det,
        jnp.ndim(log_det) - jnp.ndim(value) + len(site["fn"].event_shape))
    if site["scale"] is not None:
        log_det = site["scale"] * log_det
    numpyro.factor("_{}_log_det".format(name), log_det)
    return value
def _get_model_transforms(model, model_args=(), model_kwargs=None):
    """Trace ``model`` and gather unconstraining transforms per latent site.

    :return: ``(inv_transforms, replay_model, has_enumerate_support,
        model_trace)``; ``replay_model`` flags dynamic supports or
        deterministic sites, ``has_enumerate_support`` flags discrete sites.
    :raises RuntimeError: if a discrete site lacks enumerate support.
    """
    if model_kwargs is None:
        model_kwargs = {}
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of deterministic sites
    replay_model = False
    has_enumerate_support = False
    for name, site in model_trace.items():
        if site["type"] == "sample" and not site["is_observed"]:
            if site["fn"].is_discrete:
                has_enumerate_support = True
                if not site["fn"].has_enumerate_support:
                    raise RuntimeError(
                        "MCMC only supports continuous sites or discrete sites "
                        f"with enumerate support, but got {type(site['fn']).__name__}."
                    )
            else:
                support = site["fn"].support
                inv_transforms[name] = biject_to(support)
                # XXX: the following code filters out most situations with dynamic supports
                bound_attrs = ()
                if isinstance(support, constraints._GreaterThan):
                    bound_attrs = ("lower_bound", )
                elif isinstance(support, constraints._Interval):
                    bound_attrs = ("lower_bound", "upper_bound")
                for attr in bound_attrs:
                    # non-scalar bounds imply the support depends on traced values
                    if not isinstance(getattr(support, attr), (int, float)):
                        replay_model = True
        elif site["type"] == "deterministic":
            replay_model = True
    return inv_transforms, replay_model, has_enumerate_support, model_trace
def log_likelihood(params_flat, subsample_indices=None):
    """Compute per-plate summed log likelihoods of observed sites.

    :param params_flat: flattened unconstrained parameters (unraveled via the
        enclosing ``unravel_fn``).
    :param subsample_indices: optional dict mapping plate name to the indices
        to subsample; defaults to the full range of each subsample plate.
    :return: dict mapping plate name to its log-likelihood contribution.
    """
    if subsample_indices is None:
        subsample_indices = {name: jnp.arange(size[0])
                             for name, size in subsample_plate_sizes.items()}
    params = unravel_fn(params_flat)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # map unconstrained params back onto each site's support
        params = {name: biject_to(prototype_trace[name]["fn"].support)(value)
                  for name, value in params.items()}
        with block(), trace() as tr, substitute(data=subsample_indices), substitute(data=params):
            model(*model_args, **model_kwargs)
        log_lik = {}
        for site in tr.values():
            if site["type"] == "sample" and site["is_observed"]:
                site_ll = site["fn"].log_prob(site["value"])
                for frame in site["cond_indep_stack"]:
                    summed = _sum_all_except_at_dim(site_ll, frame.dim)
                    if frame.name in log_lik:
                        log_lik[frame.name] += summed
                    else:
                        log_lik[frame.name] = summed
    return log_lik
def test_elbo_dynamic_support():
    """Verify the ELBO value for a Uniform prior against a hand-computed
    guide/model log-density difference at a fixed unconstrained latent."""
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            # pin the latent to a fixed unconstrained value
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    guide_log_prob = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_unconstrained).sum()
    transform = transforms.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
def body_fn(state):
    """One attempt at finding valid initial parameters.

    Either traces the model under ``init_strategy`` (first pass) or perturbs
    ``prototype_params`` uniformly within ``radius``; accepts the draw when
    the potential energy and its gradient are finite.
    """
    i, key, _, _ = state
    key, subkey = random.split(key)

    if radius is None or prototype_params is None:
        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for name, site in model_trace.items():
            if site['type'] == 'sample' and not site['is_observed'] and not site['fn'].is_discrete:
                constrained_values[name] = site['value']
                inv_transforms[name] = biject_to(site['fn'].support)
        # IDIOM FIX: `dict(...)` replaces the manual `{k: v for k, v in ...}` copy
        params = transform_fn(inv_transforms, dict(constrained_values), invert=True)
    else:  # this branch doesn't require tracing the model
        params = {}
        for name, proto in prototype_params.items():
            if name in init_values:
                params[name] = init_values[name]
            else:
                params[name] = random.uniform(subkey, jnp.shape(proto),
                                              minval=-radius, maxval=radius)
                key, subkey = random.split(key)

    potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
    pe, z_grad = value_and_grad(potential_fn)(params)
    z_grad_flat = ravel_pytree(z_grad)[0]
    is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
    return i + 1, key, (params, pe, z_grad), is_valid
def __call__(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    if self.prototype_trace is None:
        # run model to inspect the model structure
        self._setup_prototype(*args, **kwargs)

    latent = self._sample_latent(*args, **kwargs)

    # unpack continuous latent samples
    result = {}
    for name, unconstrained_value in self._unpack_latent(latent).items():
        site = self.prototype_trace[name]
        transform = biject_to(site["fn"].support)
        value = transform(unconstrained_value)
        event_ndim = site["fn"].event_dim
        if numpyro.get_mask() is False:
            # masked out: skip the Jacobian computation entirely
            log_density = 0.0
        else:
            log_density = -transform.log_abs_det_jacobian(unconstrained_value, value)
            log_density = sum_rightmost(
                log_density,
                jnp.ndim(log_density) - jnp.ndim(value) + event_ndim)
        # deliver the value through a Delta carrying the density correction
        delta_dist = dist.Delta(value, log_density=log_density, event_dim=event_ndim)
        result[name] = numpyro.sample(name, delta_dist)
    return result
def test_log_prob_LKJCholesky(dimension, concentration):
    # We will test against the fact that LKJCorrCholesky can be seen as a
    # TransformedDistribution with base distribution is a distribution of partial
    # correlations in C-vine method (modulo an affine transform to change domain from (0, 1)
    # to (1, 0)) and transform is a signed stick-breaking process.
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = np.sum(d._beta.log_prob(beta_sample))
    # affine map (0, 1) -> (-1, 1): y = 2b - 1 with logdet = log(2) per entry
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * np.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick breaking logdet
    inv_tanh = lambda t: np.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = np.sum(np.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained,
        sample,
    )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=1e-5)

    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7)
def test_bijective_transforms(transform, event_shape, batch_shape):
    """Round-trip a transform (forward, inverse, Jacobian) on random input
    and compare the log-abs-det against an autodiff reference."""
    shape = batch_shape + event_shape
    rng_key = random.PRNGKey(0)
    x = biject_to(transform.domain)(random.normal(rng_key, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), np.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            # vector-valued: compare against the full Jacobian slogdet
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            # scalar: compare against log|dy/dx| via grad
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))
        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
def init(self, rng_key, *args, **kwargs):
    """
    Gets the initial SVI state.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :return: the initial :data:`SVIState`
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    seeded_model = seed(self.model, model_seed)
    seeded_guide = seed(self.guide, guide_seed)
    guide_trace = trace(seeded_guide).get_trace(*args, **kwargs, **self.static_kwargs)
    # replay so the model sees the guide's sampled values
    model_trace = trace(replay(seeded_model, guide_trace)).get_trace(
        *args, **kwargs, **self.static_kwargs)
    params = {}
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] != 'param':
            continue
        constraint = site['kwargs'].pop('constraint', constraints.real)
        transform = biject_to(constraint)
        inv_transforms[site['name']] = transform
        params[site['name']] = transform.inv(site['value'])
    self.constrain_fn = partial(transform_fn, inv_transforms)
    # we convert weak types like float to float32/float64
    # to avoid recompiling body_fn in svi.run
    params = tree_map(
        lambda x: lax.convert_element_type(x, jnp.result_type(x)), params)
    return SVIState(self.optim.init(params), rng_key)
def body_fn(state):
    """One attempt at finding valid initial parameters.

    Either traces the (de-enumerated) model under ``init_strategy`` or
    perturbs ``prototype_params`` uniformly within ``radius``; validity is
    a finite potential energy and, when ``validate_grad``, a finite gradient.
    """
    i, key, _, _ = state
    key, subkey = random.split(key)

    if radius is None or prototype_params is None:
        # XXX: we don't want to apply enum to draw latent samples
        model_ = model
        if enum:
            from numpyro.contrib.funsor import enum as enum_handler

            # unwrap the enum handler so plain sampling is used for init
            if isinstance(model, substitute) and isinstance(model.fn, enum_handler):
                model_ = substitute(model.fn.fn, data=model.data)
            elif isinstance(model, enum_handler):
                model_ = model.fn

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        seeded_model = substitute(seed(model_, subkey), substitute_fn=init_strategy)
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for name, site in model_trace.items():
            if (
                site["type"] == "sample"
                and not site["is_observed"]
                and not site["fn"].is_discrete
            ):
                constrained_values[name] = site["value"]
                inv_transforms[name] = biject_to(site["fn"].support)
        # IDIOM FIX: `dict(...)` replaces the manual `{k: v for k, v in ...}` copy
        params = transform_fn(inv_transforms, dict(constrained_values), invert=True)
    else:  # this branch doesn't require tracing the model
        params = {}
        for name, proto in prototype_params.items():
            if name in init_values:
                params[name] = init_values[name]
            else:
                params[name] = random.uniform(
                    subkey, jnp.shape(proto), minval=-radius, maxval=radius
                )
                key, subkey = random.split(key)

    potential_fn = partial(
        potential_energy, model, model_args, model_kwargs, enum=enum
    )
    if validate_grad:
        if forward_mode_differentiation:
            pe = potential_fn(params)
            z_grad = jacfwd(potential_fn)(params)
        else:
            pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
    else:
        pe = potential_fn(params)
        is_valid = jnp.isfinite(pe)
        z_grad = None

    return i + 1, key, (params, pe, z_grad), is_valid
def test_initialize_model_change_point(init_strategy):
    """Check `initialize_model` on a change-point model: batched keys produce
    the same initial params as per-key calls, and `init_to_value` is honored."""
    def model(data):
        alpha = 1 / jnp.mean(data)
        lambda1 = numpyro.sample('lambda1', dist.Exponential(alpha))
        lambda2 = numpyro.sample('lambda2', dist.Exponential(alpha))
        tau = numpyro.sample('tau', dist.Uniform(0, 1))
        lambda12 = jnp.where(jnp.arange(len(data)) < tau * len(data), lambda1, lambda2)
        numpyro.sample('obs', dist.Poisson(lambda12), obs=data)

    count_data = jnp.array([
        13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57,
        11, 19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13,
        19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2,
        15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20,
        12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
        5, 14, 13, 22,
    ])

    rng_keys = random.split(random.PRNGKey(1), 2)
    init_params, _, _, _ = initialize_model(rng_keys, model,
                                            init_strategy=init_strategy,
                                            model_args=(count_data,))
    if isinstance(init_strategy, partial) and init_strategy.func is init_to_value:
        # 'tau' should equal the requested value mapped to unconstrained space
        expected = biject_to(constraints.unit_interval).inv(
            init_strategy.keywords.get('values')['tau'])
        assert_allclose(init_params[0]['tau'], jnp.repeat(expected, 2))
    for i in range(2):
        init_params_i, _, _, _ = initialize_model(rng_keys[i], model,
                                                  init_strategy=init_strategy,
                                                  model_args=(count_data,))
        for name, p in init_params[0].items():
            # XXX: the result is equal if we disable fast-math-mode
            assert_allclose(p[i], init_params_i[0][name], atol=1e-6)
def init(self, rng_key, *args, **kwargs):
    """
    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :return: tuple containing initial :data:`SVIState`, and `get_params`, a
        callable that transforms unconstrained parameter values from the
        optimizer to the specified constrained domain
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    seeded_model = seed(self.model, model_seed)
    seeded_guide = seed(self.guide, guide_seed)
    guide_trace = trace(seeded_guide).get_trace(*args, **kwargs, **self.static_kwargs)
    # replay so the model sees the guide's sampled values
    model_trace = trace(replay(seeded_model, guide_trace)).get_trace(
        *args, **kwargs, **self.static_kwargs)
    params = {}
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] != 'param':
            continue
        constraint = site['kwargs'].pop('constraint', constraints.real)
        transform = biject_to(constraint)
        inv_transforms[site['name']] = transform
        params[site['name']] = transform.inv(site['value'])
    self.constrain_fn = partial(transform_fn, inv_transforms)
    return SVIState(self.optim.init(params), rng_key)
def __call__(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    if self.prototype_trace is None:
        # run model to inspect the model structure
        self._setup_prototype(*args, **kwargs)

    plates = self._create_plates(*args, **kwargs)
    result = {}
    for name, site in self.prototype_trace.items():
        # skip non-sample sites, rng-key helper sites, and observations
        is_skipped = (site["type"] != "sample"
                      or isinstance(site["fn"], dist.PRNGIdentity)
                      or site["is_observed"])
        if is_skipped:
            continue
        event_dim = self._event_dims[name]
        init_loc = self._init_locs[name]
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                stack.enter_context(plates[frame.name])

            site_loc = numpyro.param("{}_{}_loc".format(name, self.prefix),
                                     init_loc, event_dim=event_dim)
            site_scale = numpyro.param("{}_{}_scale".format(name, self.prefix),
                                       jnp.full(jnp.shape(init_loc), self._init_scale),
                                       constraint=constraints.positive,
                                       event_dim=event_dim)

            site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
            if site["fn"].support in [constraints.real, constraints.real_vector]:
                # unconstrained support: sample the Normal directly
                result[name] = numpyro.sample(name, site_fn)
            else:
                # sample auxiliary unconstrained value, then push through the
                # bijection and correct the density via a Delta site
                unconstrained_value = numpyro.sample(
                    "{}_unconstrained".format(name), site_fn,
                    infer={"is_auxiliary": True})
                transform = biject_to(site['fn'].support)
                value = transform(unconstrained_value)
                log_density = -transform.log_abs_det_jacobian(unconstrained_value, value)
                log_density = sum_rightmost(
                    log_density,
                    jnp.ndim(log_density) - jnp.ndim(value) + site["fn"].event_dim)
                delta_dist = dist.Delta(value, log_density=log_density,
                                        event_dim=site["fn"].event_dim)
                result[name] = numpyro.sample(name, delta_dist)
    return result
def init(site, skip_param=False):
    """Apply the wrapped ``init_strategy`` and jitter its result with Normal
    noise of scale ``noise_scale`` in the unconstrained space."""
    fn = site['fn'].base_dist if isinstance(site['fn'], dist.TransformedDistribution) else site['fn']
    vals = init_strategy(site, skip_param=skip_param)
    if vals is None:
        return None
    base_transform = biject_to(fn.support)
    noisy = numpyro.sample('_noisy_init',
                           dist.Normal(loc=base_transform.inv(vals),
                                       scale=noise_scale))
    return base_transform(noisy)
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample("offset", dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    # rebuild the flow stack by hand from the trained parameters
    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1,
            [dim + 1, dim + 1],
            permutation=jnp.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity,
        )
        arn = partial(arn_apply, params["auto_arn__{}$params".format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample["coefs"], expected_sample["coefs"])
    assert_allclose(
        actual_sample["offset"],
        transforms.biject_to(constraints.interval(-1, 1))(expected_sample["offset"]),
    )
    check_eq(actual_output, expected_output)
def init(self, rng_key, *args, **kwargs):
    """
    Gets the initial SVI state.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly
        vary during the course of fitting).
    :return: the initial :data:`SVIState`
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    seeded_model = seed(self.model, model_seed)
    seeded_guide = seed(self.guide, guide_seed)
    guide_trace = trace(seeded_guide).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = trace(replay(seeded_model, guide_trace)).get_trace(
        *args, **kwargs, **self.static_kwargs
    )
    params = {}
    inv_transforms = {}
    mutable_state = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site["type"] == "param":
            constraint = site["kwargs"].pop("constraint", constraints.real)
            with helpful_support_errors(site):
                transform = biject_to(constraint)
            inv_transforms[site["name"]] = transform
            params[site["name"]] = transform.inv(site["value"])
        elif site["type"] == "mutable":
            mutable_state[site["name"]] = site["value"]
        elif (
            site["type"] == "sample"
            and (not site["is_observed"])
            and site["fn"].support.is_discrete
            and not self.loss.can_infer_discrete
        ):
            # discrete latents cannot be handled by this loss; warn, don't fail
            s_name = type(self.loss).__name__
            warnings.warn(
                f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
            )
    if not mutable_state:
        mutable_state = None
    self.constrain_fn = partial(transform_fn, inv_transforms)
    # we convert weak types like float to float32/float64
    # to avoid recompiling body_fn in svi.run
    params, mutable_state = tree_map(
        lambda x: lax.convert_element_type(x, jnp.result_type(x)),
        (params, mutable_state),
    )
    return SVIState(self.optim.init(params), mutable_state, rng_key)
def find_params(self, rng_keys, *args, **kwargs):
    """Find valid initial parameter values and record, for each latent site,
    a ``(param_name, constrained_value, support)`` tuple.

    Populates ``self._param_map`` and ``self._init_params``.
    """
    params = {}
    # `block` hides the init search's sample statements from enclosing handlers
    init_params, _ = handlers.block(find_valid_initial_params)(
        rng_keys, self.model,
        init_strategy=self.init_strategy,
        model_args=args,
        model_kwargs=kwargs)
    for name, site in self.prototype_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            param_name = "{}_{}".format(self.prefix, name)
            # map the unconstrained init back to the site's support
            param_val = biject_to(site['fn'].support)(init_params[name])
            params[name] = (param_name, param_val, site['fn'].support)
    self._param_map = params
    self._init_params = {param: (val, constr)
                         for param, val, constr in self._param_map.values()}
def get_model_transforms(rng_key, model, model_args=(), model_kwargs=None):
    """Trace ``model`` once and collect, per latent site, the bijection from
    the unconstrained space to the site's support.

    :return: tuple ``(inv_transforms, has_transformed_dist)`` where the flag
        is True when a transformed distribution or a constrained param with a
        composed bijection was seen.
    """
    if model_kwargs is None:
        model_kwargs = {}
    # a batch of keys may be passed in; a single key suffices for tracing
    seed_key = rng_key if rng_key.ndim == 1 else rng_key[0]
    model_trace = trace(seed(model, seed_key)).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    has_transformed_dist = False
    for name, site in model_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            if site['intermediates']:
                inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                inv_transforms[name] = biject_to(site['fn'].support)
        elif site['type'] == 'param':
            constraint = site['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                inv_transforms[name] = transform.parts[0]
                has_transformed_dist = True
            else:
                inv_transforms[name] = transform
    return inv_transforms, has_transformed_dist
def _unconstrain_reparam(params, site):
    """Map an unconstrained parameter onto the site's support and emit a
    ``factor`` with the log-abs-det Jacobian so the density is preserved."""
    name = site['name']
    if name not in params:
        return None
    p = params[name]
    t = biject_to(site['fn'].support)
    value = t(p)
    log_det = t.log_abs_det_jacobian(p, value)
    log_det = sum_rightmost(
        log_det,
        jnp.ndim(log_det) - jnp.ndim(value) + len(site['fn'].event_shape))
    if site['scale'] is not None:
        log_det = site['scale'] * log_det
    numpyro.factor('_{}_log_det'.format(name), log_det)
    return value
def _get_model_transforms(model, model_args=(), model_kwargs=None):
    """Trace ``model`` and gather unconstraining transforms per latent site.

    :return: ``(inv_transforms, replay_model, has_enumerate_support,
        model_trace)``.
    :raises RuntimeError: for discrete sites with an unsupported enumerate
        annotation or without enumerate support at all.
    """
    if model_kwargs is None:
        model_kwargs = {}
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of deterministic sites
    replay_model = False
    has_enumerate_support = False
    for name, site in model_trace.items():
        if site["type"] == "sample" and not site["is_observed"]:
            if site["fn"].support.is_discrete:
                enum_type = site["infer"].get("enumerate")
                if enum_type is not None and (enum_type != "parallel"):
                    raise RuntimeError(
                        "This algorithm might only work for discrete sites with"
                        f" enumerate marked 'parallel'. But the site {name} is marked"
                        f" as '{enum_type}'.")
                has_enumerate_support = True
                if not site["fn"].has_enumerate_support:
                    dist_name = type(site["fn"]).__name__
                    raise RuntimeError(
                        "This algorithm might only work for discrete sites with"
                        f" enumerate support. But the {dist_name} distribution at"
                        f" site {name} does not have enumerate support.")
                if enum_type is None:
                    warnings.warn(
                        "Some algorithms will automatically enumerate the discrete"
                        f" latent site {name} of your model. In the future,"
                        " enumerated sites need to be marked with"
                        " `infer={'enumerate': 'parallel'}`.",
                        FutureWarning,
                        stacklevel=find_stack_level(),
                    )
            else:
                support = site["fn"].support
                with helpful_support_errors(site, raise_warnings=True):
                    inv_transforms[name] = biject_to(support)
                # XXX: the following code filters out most situations with dynamic supports
                bound_attrs = ()
                if isinstance(support, constraints._GreaterThan):
                    bound_attrs = ("lower_bound", )
                elif isinstance(support, constraints._Interval):
                    bound_attrs = ("lower_bound", "upper_bound")
                for attr in bound_attrs:
                    # non-scalar bounds imply the support depends on traced values
                    if not isinstance(getattr(support, attr), (int, float)):
                        replay_model = True
        elif site["type"] == "deterministic":
            replay_model = True
    return inv_transforms, replay_model, has_enumerate_support, model_trace
def __call__(self, *args, **kwargs):
    """
    An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

    :return: A dict mapping sample site name to sampled value.
    :rtype: dict
    """
    if self.prototype_trace is None:
        # run model to inspect the model structure
        self._setup_prototype(*args, **kwargs)

    plates = self._create_plates(*args, **kwargs)
    result = {}
    for name, site in self.prototype_trace.items():
        if site["type"] != "sample" or site["is_observed"]:
            continue
        event_dim = self._event_dims[name]
        init_loc = self._init_locs[name]
        with ExitStack() as stack:
            for frame in site["cond_indep_stack"]:
                stack.enter_context(plates[frame.name])

            site_loc = numpyro.param("{}_{}_loc".format(name, self.prefix),
                                     init_loc, event_dim=event_dim)
            site_scale = numpyro.param(
                "{}_{}_scale".format(name, self.prefix),
                jnp.full(jnp.shape(init_loc), self._init_scale),
                constraint=self.scale_constraint,
                event_dim=event_dim,
            )

            site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
            # BUGFIX: the second clause previously re-tested
            # `site["fn"].support is constraints.real`, which can never hold
            # for an `independent` instance; it must inspect `base_constraint`
            # so real-vector supports take the direct-sampling fast path.
            if site["fn"].support is constraints.real or (
                isinstance(site["fn"].support, constraints.independent)
                and site["fn"].support.base_constraint is constraints.real
            ):
                result[name] = numpyro.sample(name, site_fn)
            else:
                # constrained support: push the Normal through the bijection
                transform = biject_to(site["fn"].support)
                guide_dist = dist.TransformedDistribution(site_fn, transform)
                result[name] = numpyro.sample(name, guide_dist)
    return result
def _init_to_median(site, num_samples=15, skip_param=False):
    """Initialize a sample site to the element-wise median of ``num_samples``
    prior draws; param sites return their base (pre-transform) value."""
    if site['type'] == 'sample' and not site['is_observed']:
        # work with the base distribution of a transformed distribution
        fn = site['fn'].base_dist if isinstance(site['fn'], dist.TransformedDistribution) else site['fn']
        samples = numpyro.sample(
            '_init', fn,
            sample_shape=(num_samples,) + site['kwargs']['sample_shape'])
        return np.median(samples, axis=0)
    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        if isinstance(transform, ComposeTransform):
            base_transform = transform.parts[0]
            value = base_transform(transform.inv(value))
        return value
def _init_to_value(site, values=None, skip_param=False):
    """Initialize ``site`` to a user-provided value from ``values``, falling
    back to ``_init_to_uniform`` for latent sites without an entry.

    BUGFIX: ``values`` previously used a mutable default argument (``{}``);
    a ``None`` sentinel is used instead (behavior unchanged for callers).

    :param site: trace site dict.
    :param values: optional dict mapping site name to its initial value.
    :param bool skip_param: if True, param sites are left untouched.
    """
    values = {} if values is None else values
    if site['type'] == 'sample' and not site['is_observed']:
        if site['name'] not in values:
            return _init_to_uniform(site, skip_param=skip_param)
        value = values[site['name']]
        if isinstance(site['fn'], dist.TransformedDistribution):
            # user supplies the transformed value; recover the base value
            value = ComposeTransform(site['fn'].transforms).inv(value)
        return value
    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        if isinstance(transform, ComposeTransform):
            base_transform = transform.parts[0]
            value = base_transform(transform.inv(value))
        return value