def unpack_single_latent(latent):
    """Unpack one latent value and map it through the inverse transforms.

    ``self``, ``params``, ``model_args`` and ``model_kwargs`` are free
    variables resolved from the enclosing scope (not visible in this block).

    :param latent: a single latent value accepted by ``self.unpack_latent``.
    :return: the result of ``constrain_fn`` when
        ``self._has_transformed_dist`` is truthy, otherwise the result of
        ``transform_fn``.
    """
    site_values = self.unpack_latent(latent)

    # Simple case: apply the inverse transforms to the unpacked values directly.
    if not self._has_transformed_dist:
        return transform_fn(self._inv_transforms, site_values)

    # Otherwise hand `constrain_fn` a single mapping holding both `params`
    # and the unpacked samples; on a name clash the unpacked sample wins.
    merged_values = {**params, **site_values}
    return constrain_fn(self.model, model_args, model_kwargs,
                        self._inv_transforms, merged_values)
def unpack_single_latent(latent):
    """Unpack one latent value and map it through the inverse transforms.

    ``self``, ``params``, ``model_args`` and ``model_kwargs`` are free
    variables resolved from the enclosing scope (not visible in this block).

    :param latent: a single latent value accepted by ``self.unpack_latent``.
    :return: the result of ``constrain_fn`` when
        ``self._has_transformed_dist`` is truthy, otherwise the result of
        ``transform_fn``.
    """
    site_values = self.unpack_latent(latent)

    # Simple case: apply the inverse transforms to the unpacked values directly.
    if not self._has_transformed_dist:
        return transform_fn(self._inv_transforms, site_values)

    # first, substitute `params` into the `param` statements of the model
    substituted_model = handlers.substitute(self.model, params)
    return constrain_fn(substituted_model, model_args, model_kwargs,
                        self._inv_transforms, site_values)
def median(self, opt_state):
    """
    Returns the posterior median value of each latent variable.

    The location returned by ``self._loc_scale(opt_state)`` is unraveled with
    ``self._unravel_fn`` and then passed through ``self._inv_transforms``;
    the scale is not used.

    :param opt_state: Current state of the optimizer.
    :return: A dict mapping sample site name to median tensor.
    :rtype: dict
    """
    location, _unused_scale = self._loc_scale(opt_state)
    unraveled_location = self._unravel_fn(location)
    return transform_fn(self._inv_transforms, unraveled_location)
def _potential_energy(params):
    """Return the negated log joint density at ``params``.

    ``inv_transforms``, ``model``, ``model_args`` and ``model_kwargs`` are
    free variables resolved from the enclosing scope (not visible here).

    :param params: mapping from site name to the value that is fed to the
        matching entry of ``inv_transforms``.
    :return: ``-(log_joint + sum of per-site log|det J| terms)``, where each
        Jacobian term is multiplied by the site's ``'scale'`` entry when the
        traced site has one.
    """
    constrained = transform_fn(inv_transforms, params)
    log_joint, model_trace = log_density(model, model_args, model_kwargs, constrained)

    for site_name, transform in inv_transforms.items():
        jacobian = transform.log_abs_det_jacobian(params[site_name],
                                                  constrained[site_name])
        log_det = np.sum(jacobian)

        # Apply the same per-site scale factor to the Jacobian correction.
        site = model_trace[site_name]
        if 'scale' in site:
            log_det = site['scale'] * log_det

        # Deliberately not `+=`: rebind instead of updating in place.
        log_joint = log_joint + log_det

    return -log_joint
def sample_posterior(self, rng, opt_state, *args, **kwargs):
    """Draw latent samples from ``Normal(loc, scale)`` and constrain them.

    :param rng: random key forwarded to ``dist.Normal(...).sample``.
    :param opt_state: optimizer state passed to ``self._loc_scale`` to obtain
        ``loc`` and ``scale``.
    :param args: accepted but not used by this method.
    :param kwargs: only ``sample_shape`` (default ``()``) is read; it is
        popped from ``kwargs``.
    :return: the result of ``transform_fn(self._inv_transforms, ...)`` applied
        to the unraveled samples; with a non-empty ``sample_shape`` every
        unraveled value is reshaped to lead with ``sample_shape``.
    """
    sample_shape = kwargs.pop('sample_shape', ())
    loc, scale = self._loc_scale(opt_state)
    num_samples = int(np.prod(sample_shape))
    latent_sample = dist.Normal(loc, scale).sample(rng, sample_shape)

    if sample_shape:
        # Collapse all sample dimensions into one leading axis so the
        # unravel function can be vmapped over it.
        trailing_dims = np.shape(latent_sample)[len(sample_shape):]
        flat_latent = np.reshape(latent_sample, (num_samples,) + trailing_dims)
        batched = vmap(self._unravel_fn)(flat_latent)

        # Restore the requested sample shape in front of each site's own dims.
        unpacked = {}
        for site_name, site_value in batched.items():
            restored_shape = sample_shape + np.shape(site_value)[1:]
            unpacked[site_name] = np.reshape(site_value, restored_shape)
    else:
        unpacked = self._unravel_fn(latent_sample)

    return transform_fn(self._inv_transforms, unpacked)
def single_chain_init(key, only_params=False):
    """Build initial parameters (and optionally helper callables) for one chain.

    ``model``, ``model_args``, ``model_kwargs`` and ``init_strategy`` are free
    variables resolved from the enclosing scope (not visible in this block).

    :param key: random key used to seed the model and, for the ``'uniform'``
        strategy, to draw the initial values.
    :param bool only_params: when truthy, return just the initial parameters.
    :return: ``init_params`` if ``only_params`` is truthy, otherwise the tuple
        ``(init_params, potential_energy(...), jax.partial(transform_fn, inv_transforms))``.
    :raises ValueError: if ``init_strategy`` is neither ``'uniform'`` nor
        ``'prior'``.
    """
    seeded_model = seed(model, key)
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)

    # Collect a value and a bijection for every unobserved sample site and
    # every param site; all other sites are ignored.
    constrained_values, inv_transforms = {}, {}
    for site_name, site in model_trace.items():
        site_type = site['type']
        if site_type == 'sample' and not site['is_observed']:
            constrained_values[site_name] = site['value']
            domain = site['fn'].support
        elif site_type == 'param':
            constrained_values[site_name] = site['value']
            # NB: `pop` removes the constraint from the traced site's kwargs.
            domain = site['kwargs'].pop('constraint', real)
        else:
            continue
        inv_transforms[site_name] = biject_to(domain)

    prior_params = transform_fn(inv_transforms, dict(constrained_values), invert=True)

    if init_strategy == 'uniform':
        # One fresh key per site; values are drawn from U(-2, 2) with the
        # shape of the corresponding prior draw.
        init_params = {}
        for site_name, prior_value in prior_params.items():
            (key,) = random.split(key, 1)
            init_params[site_name] = random.uniform(
                key, shape=np.shape(prior_value), minval=-2, maxval=2)
    elif init_strategy == 'prior':
        init_params = prior_params
    else:
        message = 'initialize={} is not a valid initialization strategy.'.format(
            init_strategy)
        raise ValueError(message)

    if only_params:
        return init_params

    potential_fn = potential_energy(seeded_model, model_args, model_kwargs, inv_transforms)
    constrain_fun = jax.partial(transform_fn, inv_transforms)
    return (init_params, potential_fn, constrain_fun)
def initialize_model(rng, model, *model_args, init_strategy=init_to_uniform, **model_kwargs):
    """
    Given a model with Pyro primitives, returns a function which, given
    unconstrained parameters, evaluates the potential energy (negative
    joint density). In addition, this also returns initial parameters
    used to initiate MCMC sampling and a function to transform
    unconstrained values at sample sites to constrained values within
    their respective support.

    :param jax.random.PRNGKey rng: random number generator seed. A key with
        ``rng.ndim == 1`` initializes a single chain; otherwise the
        initialization is mapped over ``rng`` and the returned `init_params`
        will have the batch shape ``rng.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param callable init_strategy: a per-site initialization function.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
        `init_params` are values used to initiate MCMC, `constrain_fn` is a
        callable that uses inverse transforms to convert unconstrained HMC
        samples to constrained values that lie within the site's support.
    :raises RuntimeError: if the validity flag returned by the initializer
        is a device array that is not all true.
    """
    is_single_chain = rng.ndim == 1
    seeded_model = seed(model, rng if is_single_chain else rng[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)

    constrained_values, inv_transforms = {}, {}
    has_transformed_dist = False
    for site_name, site in model_trace.items():
        site_type = site['type']
        if site_type == 'sample' and not site['is_observed']:
            intermediates = site['intermediates']
            if intermediates:
                # Work with the first recorded intermediate value and the
                # support of the base distribution.
                site_value = intermediates[0][0]
                site_transform = biject_to(site['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                site_value = site['value']
                site_transform = biject_to(site['fn'].support)
        elif site_type == 'param':
            # NB: `pop` removes the constraint from the traced site's kwargs.
            full_transform = biject_to(site['kwargs'].pop('constraint', real))
            if isinstance(full_transform, ComposeTransform):
                # Keep only the first part of the composed transform and
                # express the value in that part's codomain.
                site_transform = full_transform.parts[0]
                site_value = site_transform(full_transform.inv(site['value']))
                has_transformed_dist = True
            else:
                site_transform = full_transform
                site_value = site['value']
        else:
            continue
        constrained_values[site_name] = site_value
        inv_transforms[site_name] = site_transform

    prototype_params = transform_fn(inv_transforms, dict(constrained_values), invert=True)

    # NB: we use model instead of seeded_model to prevent unexpected behaviours (if any)
    potential_fn = jax.partial(potential_energy, model, model_args, model_kwargs,
                               inv_transforms)
    if has_transformed_dist:
        # FIXME: why using seeded_model here triggers an error for funnel reparam example
        # if we use MCMC class (mcmc function works fine)
        constrain_fun = jax.partial(constrain_fn, model, model_args, model_kwargs,
                                    inv_transforms)
    else:
        constrain_fun = jax.partial(transform_fn, inv_transforms)

    def init_one_chain(chain_key):
        # Search for valid initial parameters for a single chain.
        return find_valid_initial_params(chain_key, model, *model_args,
                                         init_strategy=init_strategy,
                                         param_as_improper=True,
                                         prototype_params=prototype_params,
                                         **model_kwargs)

    if is_single_chain:
        init_params, is_valid = init_one_chain(rng)
    else:
        init_params, is_valid = lax.map(init_one_chain, rng)

    # The validity check only runs when `is_valid` is a device array.
    if (isinstance(is_valid, jax.interpreters.xla.DeviceArray)
            and device_get(~np.all(is_valid))):
        raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
    return init_params, potential_fn, constrain_fun