def test_model_with_transformed_distribution():
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    params = {'x': np.array(-5.), 'y': np.array(7.)}
    model = handlers.seed(model, random.PRNGKey(0))
    inv_transforms = {'x': biject_to(x_prior.support),
                      'y': biject_to(y_prior.support)}
    expected_samples = partial(transform_fn, inv_transforms)(params)
    expected_potential_energy = (
        - x_prior.log_prob(expected_samples['x'])
        - y_prior.log_prob(expected_samples['y'])
        - inv_transforms['x'].log_abs_det_jacobian(params['x'], expected_samples['x'])
        - inv_transforms['y'].log_abs_det_jacobian(params['y'], expected_samples['y'])
    )

    base_inv_transforms = {'x': biject_to(x_prior.support),
                           'y': biject_to(y_prior.base_dist.support)}
    actual_samples = constrain_fn(
        handlers.seed(model, random.PRNGKey(0)), (), {}, base_inv_transforms, params)
    actual_potential_energy = potential_energy(model, (), {}, base_inv_transforms, params)

    assert_allclose(expected_samples['x'], actual_samples['x'])
    assert_allclose(expected_samples['y'], actual_samples['y'])
    assert_allclose(actual_potential_energy, expected_potential_energy)
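
# The test above uses `transform_fn` to map a dict of unconstrained values
# through per-site transforms. A minimal sketch of that contract, inferred
# from how `transform_fn` is called throughout this code (the name
# `transform_fn_sketch` is hypothetical, for illustration only):
def transform_fn_sketch(transforms, params, invert=False):
    # apply each site's transform (or its inverse) to the matching entry
    return {k: transforms[k].inv(v) if invert else transforms[k](v)
            for k, v in params.items()}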
def init_to_uniform(site, radius=2, skip_param=False):
    """
    Initialize to a random point in the interval `(-radius, radius)` of the
    unconstrained domain.
    """
    if site['type'] == 'sample' and not site['is_observed']:
        if isinstance(site['fn'], dist.TransformedDistribution):
            fn = site['fn'].base_dist
        else:
            fn = site['fn']
        value = numpyro.sample('_init', fn,
                               sample_shape=site['kwargs']['sample_shape'])
        base_transform = biject_to(fn.support)
        unconstrained_value = numpyro.sample('_unconstrained_init',
                                             dist.Uniform(-radius, radius),
                                             sample_shape=np.shape(base_transform.inv(value)))
        return base_transform(unconstrained_value)
    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        unconstrained_value = numpyro.sample('_unconstrained_init',
                                             dist.Uniform(-radius, radius),
                                             sample_shape=np.shape(transform.inv(value)))
        if isinstance(transform, ComposeTransform):
            base_transform = transform.parts[0]
        else:
            base_transform = transform
        return base_transform(unconstrained_value)
def _setup_prototype(self, *args, **kwargs):
    super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
    rng = numpyro.sample("_{}_rng_init".format(self.prefix), dist.PRNGIdentity())
    # FIXME: without block statement, get AssertionError: all sites must have unique names
    init_params, is_valid = handlers.block(find_valid_initial_params)(
        rng, self.model, *args, init_strategy=self.init_strategy, **kwargs)
    self._inv_transforms = {}
    self._has_transformed_dist = False
    unconstrained_sites = {}
    for name, site in self.prototype_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            if site['intermediates']:
                transform = biject_to(site['fn'].base_dist.support)
                self._inv_transforms[name] = transform
                unconstrained_sites[name] = transform.inv(site['intermediates'][0][0])
                self._has_transformed_dist = True
            else:
                transform = biject_to(site['fn'].support)
                self._inv_transforms[name] = transform
                unconstrained_sites[name] = transform.inv(site['value'])
    self._init_latent, self.unpack_latent = ravel_pytree(init_params)
    self.latent_size = np.size(self._init_latent)
    if self.base_dist is None:
        self.base_dist = _Normal(np.zeros(self.latent_size), 1.)
    if self.latent_size == 0:
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                           .format(type(self).__name__))
def initialize_model(rng, model, *model_args, init_strategy='uniform', **model_kwargs):
    """
    Given a model with Pyro primitives, returns a function which, given
    unconstrained parameters, evaluates the potential energy (negative
    joint density). In addition, this returns initial parameters sampled
    from the prior to initiate MCMC sampling and functions to transform
    unconstrained values at sample sites to constrained values within
    their respective support.

    :param jax.random.PRNGKey rng: random number generator seed to
        sample from the prior.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param str init_strategy: initialization strategy - `uniform`
        initializes the unconstrained parameters by drawing from
        a `Uniform(-2, 2)` distribution (as used by Stan), whereas
        `prior` initializes the parameters by sampling from the prior
        for each of the sample sites.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
        `init_params` are values from the prior used to initiate MCMC,
        `constrain_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support.
    """
    model = seed(model, rng)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    sample_sites = {k: v for k, v in model_trace.items()
                    if v['type'] == 'sample' and not v['is_observed']}
    inv_transforms = {k: biject_to(v['fn'].support) for k, v in sample_sites.items()}
    prior_params = constrain_fn(inv_transforms,
                                {k: v['value'] for k, v in sample_sites.items()},
                                invert=True)
    if init_strategy == 'uniform':
        init_params = {}
        for k, v in prior_params.items():
            rng, = random.split(rng, 1)
            init_params[k] = random.uniform(rng, shape=np.shape(v), minval=-2, maxval=2)
    elif init_strategy == 'prior':
        init_params = prior_params
    else:
        raise ValueError('initialize={} is not a valid initialization strategy.'
                         .format(init_strategy))
    return (init_params,
            potential_energy(model, model_args, model_kwargs, inv_transforms),
            jax.partial(constrain_fn, inv_transforms))
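
# A minimal usage sketch for the `initialize_model` above, following its
# documented return contract. The model and data here are hypothetical, and
# the usual imports are assumed (`numpyro`, `numpyro.distributions as dist`,
# `jax.random as random`, `jax.numpy as np`).
data = np.ones(10)

def model(data):
    loc = numpyro.sample('loc', dist.Normal(0., 1.))
    numpyro.sample('obs', dist.Normal(loc, 1.), obs=data)

init_params, potential_fn, constrain = initialize_model(
    random.PRNGKey(0), model, data, init_strategy='uniform')
pe = potential_fn(init_params)      # potential energy at the initial point
constrained = constrain(init_params)  # values inside each site's support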
def test_log_prob_LKJCholesky(dimension, concentration):
    # We test against the fact that LKJCorrCholesky can be seen as a
    # TransformedDistribution whose base distribution is a distribution of
    # partial correlations in the C-vine method (modulo an affine transform
    # to change the domain from (0, 1) to (-1, 1)) and whose transform is a
    # signed stick-breaking process.
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = np.sum(d._beta.log_prob(beta_sample))
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * np.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick-breaking logdet
    inv_tanh = lambda t: np.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = np.sum(np.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained, sample)
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=1e-5)

    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=1e-7)
def test_bijective_transforms(transform, event_shape, batch_shape):
    shape = batch_shape + event_shape
    rng = random.PRNGKey(0)
    x = biject_to(transform.domain)(random.normal(rng, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), np.ones(batch_shape))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if isinstance(transform, PermuteTransform):
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
def test_elbo_dynamic_support():
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(super(_AutoGuide, self).__call__,
                              {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(random.PRNGKey(0), (), ())
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    guide_log_prob = dist.Normal(guide._init_latent).log_prob(x_unconstrained).sum()
    transform = constraints.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss)
def test_biject_to(constraint, shape):
    transform = biject_to(constraint)
    if isinstance(constraint, constraints._Interval):
        assert transform.codomain.upper_bound == constraint.upper_bound
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._GreaterThan):
        assert transform.codomain.lower_bound == constraint.lower_bound
    if len(shape) < transform.event_dim:
        return
    rng = random.PRNGKey(0)
    x = random.normal(rng, shape)
    y = transform(x)

    # test codomain
    batch_shape = shape if transform.event_dim == 0 else shape[:-1]
    assert_array_equal(transform.codomain(y), np.ones(batch_shape, dtype=np.bool_))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-6)

    # test domain; currently all domains are constraints.real or constraints.real_vector
    assert_array_equal(transform.domain(z), np.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if constraint is constraints.simplex:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x)[:-1, :])[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y)[:, :-1])[1]
        elif constraint is constraints.corr_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x), diagonal=-1)  # noqa: E731
            y_tril = matrix_to_tril_vec(y, diagonal=-1)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x, diagonal=-1))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        elif constraint is constraints.lower_cholesky:
            vec_transform = lambda x: matrix_to_tril_vec(transform(x))  # noqa: E731
            y_tril = matrix_to_tril_vec(y)
            inv_vec_transform = lambda x: transform.inv(vec_to_tril_matrix(x))  # noqa: E731
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)
def init_to_feasible(site):
    """
    Initialize to an arbitrary feasible point, ignoring distribution
    parameters.
    """
    if site['is_observed']:
        return None
    value = sample('_init', site['fn'])
    t = biject_to(site['fn'].support)
    return t(np.zeros(np.shape(t.inv(value))))
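
# Concretely, the feasible point above is the image of the unconstrained zero
# vector. Assuming numpyro's sigmoid-based interval bijection, for a
# Uniform(0, 5) site this is the interval midpoint:
t = biject_to(constraints.interval(0, 5))
assert_allclose(t(np.zeros(())), 2.5)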
def body_fn(state):
    i, key, _, _ = state
    key, subkey = random.split(key)

    # Wrap model in a `substitute` handler to initialize from `init_strategy`.
    # Use `block` to not record sample primitives in `init_strategy`.
    seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                constrained_values[k] = v['intermediates'][0][0]
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
            else:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param' and param_as_improper:
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                base_transform = transform.parts[0]
                inv_transforms[k] = base_transform
                constrained_values[k] = base_transform(transform.inv(v['value']))
            else:
                inv_transforms[k] = transform
                constrained_values[k] = v['value']
    params = transform_fn(inv_transforms,
                          {k: v for k, v in constrained_values.items()},
                          invert=True)
    potential_fn = jax.partial(potential_energy, model, model_args, model_kwargs,
                               inv_transforms)
    pe, param_grads = value_and_grad(potential_fn)(params)
    z_grad = ravel_pytree(param_grads)[0]
    is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
    return i + 1, key, params, is_valid
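
# Sketch of the loop that drives a `body_fn` like the one above, using the
# generic `jax.lax.while_loop` pattern: retry fresh draws until the potential
# energy and its gradient are finite. This is an illustration, not the
# library's exact `find_valid_initial_params`; the 100-trial budget is an
# assumed constant, and `prototype_params` must match the pytree structure
# that `body_fn` returns (compare `prototype_params` in `initialize_model`
# below).
def find_valid_params_sketch(rng, prototype_params):
    def cond_fn(state):
        i, _, _, is_valid = state
        # keep searching while the draw is invalid and the budget remains
        return (i < 100) & ~is_valid

    init_state = (0, rng, prototype_params, False)
    _, _, params, is_valid = lax.while_loop(cond_fn, body_fn, init_state)
    return params, is_valid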
def test_iaf():
    # test substitute logic for the exposed methods `sample_posterior` and `get_transform`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(rng_init, model_args=(data, labels), guide_args=(data, labels))
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1,))
    rng = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng, params)
    actual_output = guide.get_transform(params)(x)

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(constraints.PermuteTransform(np.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(dim + 1, [dim + 1, dim + 1],
                                               permutation=np.arange(dim + 1),
                                               skip_connections=guide._skip_connections,
                                               nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))

    transform = constraints.ComposeTransform(flows)
    rng_seed, rng_sample = random.split(rng)
    expected_sample = guide.unpack_latent(
        transform(dist.Normal(np.zeros(dim + 1), 1).sample(rng_sample)))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(actual_sample['offset'],
                    constraints.biject_to(constraints.interval(-1, 1))(
                        expected_sample['offset']))
    assert_allclose(actual_output, expected_output)
def single_chain_init(key, only_params=False):
    seeded_model = seed(model, key)
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            constrained_values[k] = v['value']
            inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            constrained_values[k] = v['value']
            constraint = v['kwargs'].pop('constraint', real)
            inv_transforms[k] = biject_to(constraint)
    prior_params = transform_fn(inv_transforms,
                                {k: v for k, v in constrained_values.items()},
                                invert=True)
    if init_strategy == 'uniform':
        init_params = {}
        for k, v in prior_params.items():
            key, = random.split(key, 1)
            init_params[k] = random.uniform(key, shape=np.shape(v), minval=-2, maxval=2)
    elif init_strategy == 'prior':
        init_params = prior_params
    else:
        raise ValueError('initialize={} is not a valid initialization strategy.'
                         .format(init_strategy))
    if only_params:
        return init_params
    else:
        return (init_params,
                potential_energy(seeded_model, model_args, model_kwargs, inv_transforms),
                jax.partial(transform_fn, inv_transforms))
def process_message(self, msg):
    if self.param_map is not None:
        if msg['name'] in self.param_map:
            msg['value'] = self.param_map[msg['name']]
    elif self.base_param_map is not None:
        if msg['name'] in self.base_param_map:
            if msg['type'] == 'sample':
                msg['value'], msg['intermediates'] = msg['fn'].transform_with_intermediates(
                    self.base_param_map[msg['name']])
            else:
                base_value = self.base_param_map[msg['name']]
                constraint = msg['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    msg['value'] = ComposeTransform(transform.parts[1:])(base_value)
                else:
                    msg['value'] = self.base_param_map[msg['name']]
    elif self.substitute_fn is not None:
        base_value = self.substitute_fn(msg)
        if base_value is not None:
            if msg['type'] == 'sample':
                msg['value'], msg['intermediates'] = msg['fn'].transform_with_intermediates(
                    base_value)
            else:
                constraint = msg['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    msg['value'] = ComposeTransform(transform.parts[1:])(base_value)
                else:
                    msg['value'] = base_value
    else:
        raise ValueError("Neither `param_map`, `base_param_map`, nor `substitute_fn` "
                         "provided to substitute handler.")
def init_fn(rng, model_args=(), guide_args=(), params=None):
    """
    :param jax.random.PRNGKey rng: random number generator seed.
    :param tuple model_args: arguments to the model (these can possibly vary
        during the course of fitting).
    :param tuple guide_args: arguments to the guide (these can possibly vary
        during the course of fitting).
    :param dict params: initial parameter values to condition on. This can be
        useful for initializing neural networks using more specialized methods
        rather than sampling from the prior.
    :return: tuple containing initial optimizer state, and `constrain_fn`, a
        callable that transforms unconstrained parameter values from the
        optimizer to the specified constrained domain
    """
    assert isinstance(model_args, tuple)
    assert isinstance(guide_args, tuple)
    model_init, guide_init = _seed(model, guide, rng)
    if params is None:
        params = {}
    else:
        model_init = substitute(model_init, params)
        guide_init = substitute(guide_init, params)
    guide_trace = trace(guide_init).get_trace(*guide_args, **kwargs)
    model_trace = trace(model_init).get_trace(*model_args, **kwargs)
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] == 'param':
            constraint = site['kwargs'].pop('constraint', constraints.real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                base_transform = transform.parts[0]
                inv_transforms[site['name']] = base_transform
                params[site['name']] = base_transform(transform.inv(site['value']))
            else:
                inv_transforms[site['name']] = transform
                params[site['name']] = site['value']
    nonlocal constrain_fn
    constrain_fn = jax.partial(transform_fn, inv_transforms)
    return optim_init(params), constrain_fn
def _setup_prototype(self, *args, **kwargs):
    super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
    self._inv_transforms = {}
    unconstrained_sites = {}
    for name, site in self.prototype_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            # Collect the shapes of unconstrained values.
            # These may differ from the shapes of constrained values.
            transform = biject_to(site['fn'].support)
            unconstrained_val = transform.inv(site['value'])
            self._inv_transforms[name] = transform
            unconstrained_sites[name] = unconstrained_val

    latent_size = sum(np.size(x) for x in unconstrained_sites.values())
    if latent_size == 0:
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                           .format(type(self).__name__))
    self._init_latent, self._unravel_fn = ravel_pytree(unconstrained_sites)
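
# `ravel_pytree` is what turns the dict of per-site unconstrained values above
# into a single flat latent vector and returns the inverse unravel function.
# A self-contained illustration of that round trip:
import jax.numpy as np
from jax.flatten_util import ravel_pytree

sites = {'loc': np.zeros(3), 'cov_factor': np.zeros((2, 2))}
flat, unravel_fn = ravel_pytree(sites)
assert flat.shape == (7,)  # 3 + 4 entries, concatenated into one vector
assert unravel_fn(flat)['cov_factor'].shape == (2, 2)  # original shapes restored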
def process_message(self, msg):
    if self.param_map is not None:
        if msg['name'] in self.param_map:
            msg['value'] = self.param_map[msg['name']]
    else:
        base_value = self.substitute_fn(msg) if self.substitute_fn \
            else self.base_param_map.get(msg['name'], None)
        if base_value is not None:
            if msg['type'] == 'sample':
                msg['value'], msg['intermediates'] = msg['fn'].transform_with_intermediates(
                    base_value)
            else:
                constraint = msg['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    # No need to apply the first transform since the base value
                    # should have the same support as the first part's co-domain.
                    msg['value'] = ComposeTransform(transform.parts[1:])(base_value)
                else:
                    msg['value'] = base_value
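
# Sketch of how the `substitute` handler above is typically used: pin a sample
# site to a fixed value while tracing a model. The model below is hypothetical;
# the usual imports are assumed (`numpyro`, `dist`, and the `seed`, `trace`,
# `substitute` handlers).
def model():
    x = numpyro.sample('x', dist.Normal(0., 1.))
    return numpyro.sample('y', dist.Normal(x, 1.))

substituted = substitute(model, {'x': 0.5})
tr = trace(seed(substituted, random.PRNGKey(0))).get_trace()
assert tr['x']['value'] == 0.5  # 'x' takes the substituted value, not a fresh draw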
def init_to_median(site, num_samples=15, skip_param=False):
    """
    Initialize to the prior median.
    """
    if site['type'] == 'sample' and not site['is_observed']:
        if isinstance(site['fn'], dist.TransformedDistribution):
            fn = site['fn'].base_dist
        else:
            fn = site['fn']
        samples = sample('_init', fn, sample_shape=(num_samples,))
        return np.median(samples, axis=0)
    if site['type'] == 'param' and not skip_param:
        # return base value of param site
        constraint = site['kwargs'].pop('constraint', real)
        transform = biject_to(constraint)
        value = site['args'][0]
        if isinstance(transform, ComposeTransform):
            base_transform = transform.parts[0]
            value = base_transform(transform.inv(value))
        return value
def _loc_scale(self, opt_state):
    params = self.get_params(opt_state)
    loc = params['{}_loc'.format(self.prefix)]
    scale = biject_to(constraints.positive)(params['{}_scale'.format(self.prefix)])
    return loc, scale
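
# `biject_to(constraints.positive)` returns a bijection from the reals onto
# the positive half-line (an exp-like transform in numpyro), which is why the
# raw `_scale` parameter above can be optimized unconstrained. A small
# round-trip check:
t = biject_to(constraints.positive)
raw = np.array([-1., 0., 2.])
scale = t(raw)  # strictly positive values
assert np.all(scale > 0)
assert_allclose(t.inv(scale), raw, atol=1e-6)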
def initialize_model(rng, model, *model_args, init_strategy=init_to_uniform, **model_kwargs):
    """
    Given a model with Pyro primitives, returns a function which, given
    unconstrained parameters, evaluates the potential energy (negative
    joint density). In addition, this returns initial parameters sampled
    from the prior to initiate MCMC sampling and functions to transform
    unconstrained values at sample sites to constrained values within
    their respective support.

    :param jax.random.PRNGKey rng: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
        `init_params` are values from the prior used to initiate MCMC,
        `constrain_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support.
    """
    seeded_model = seed(model, rng if rng.ndim == 1 else rng[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    has_transformed_dist = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                constrained_values[k] = v['intermediates'][0][0]
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                base_transform = transform.parts[0]
                constrained_values[k] = base_transform(transform.inv(v['value']))
                inv_transforms[k] = base_transform
                has_transformed_dist = True
            else:
                inv_transforms[k] = transform
                constrained_values[k] = v['value']

    prototype_params = transform_fn(inv_transforms,
                                    {k: v for k, v in constrained_values.items()},
                                    invert=True)

    # NB: we use model instead of seeded_model to prevent unexpected behaviours (if any)
    potential_fn = jax.partial(potential_energy, model, model_args, model_kwargs,
                               inv_transforms)
    if has_transformed_dist:
        # FIXME: why does using seeded_model here trigger an error for the funnel
        # reparam example when the MCMC class is used (the mcmc function works fine)?
        constrain_fun = jax.partial(constrain_fn, model, model_args, model_kwargs,
                                    inv_transforms)
    else:
        constrain_fun = jax.partial(transform_fn, inv_transforms)

    def single_chain_init(key):
        return find_valid_initial_params(key, model, *model_args,
                                         init_strategy=init_strategy,
                                         param_as_improper=True,
                                         prototype_params=prototype_params,
                                         **model_kwargs)

    if rng.ndim == 1:
        init_params, is_valid = single_chain_init(rng)
    else:
        init_params, is_valid = lax.map(single_chain_init, rng)

    if isinstance(is_valid, jax.interpreters.xla.DeviceArray):
        if device_get(~np.all(is_valid)):
            raise RuntimeError("Cannot find valid initial parameters. "
                               "Please check your model again.")

    return init_params, potential_fn, constrain_fun
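
# Usage sketch for the vectorized path of this `initialize_model`: passing a
# batch of PRNG keys yields one set of initial parameters per chain, as the
# docstring's note on batch shape describes. The model and data here are
# hypothetical; the usual numpyro/jax imports are assumed.
data = np.ones(10)

def model(data):
    loc = numpyro.sample('loc', dist.Normal(0., 1.))
    scale = numpyro.sample('scale', dist.HalfNormal(1.))
    numpyro.sample('obs', dist.Normal(loc, scale), obs=data)

rngs = random.split(random.PRNGKey(0), 4)  # one key per chain
init_params, potential_fn, constrain_fun = initialize_model(
    rngs, model, data, init_strategy=init_to_median)
# every leaf of `init_params` now has a leading batch dimension of 4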