def _setup_prototype(self, *args, **kwargs):
    """Build the flat latent-vector view of the guide's initial locations.

    The parent class's ``_setup_prototype`` runs first. Afterwards,
    ``self._init_locs`` (presumably populated by that parent call) is
    flattened with ``ravel_pytree``.

    Attributes set on ``self``:
        ``_init_latent``: the flattened initial locations.
        ``_unpack_latent``: an ``UnpackTransform`` built from the unravel
            function returned by ``ravel_pytree``.
        ``latent_dim``: ``jnp.size`` of the flattened vector.

    :raises RuntimeError: if the flattened vector has zero elements, i.e. no
        latent variables were found.
    """
    super()._setup_prototype(*args, **kwargs)
    flat_locs, unravel_fn = ravel_pytree(self._init_locs)
    self._init_latent = flat_locs
    # Wrapping the unravel function in UnpackTransform matches Pyro's
    # behavior: unpacking can then be applied to a batch of flat samples,
    # not only to a single vector.
    self._unpack_latent = UnpackTransform(unravel_fn)
    self.latent_dim = jnp.size(flat_locs)
    if not self.latent_dim:
        guide_name = type(self).__name__
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                           .format(guide_name))
def get_transform(self, params):
    """
    Returns the transformation learned by the guide to generate samples from the unconstrained
    (approximate) posterior.

    The result is a ``ComposeTransform`` of two steps:

    1. The transform returned by ``self._get_transform()``, evaluated under
       ``handlers.substitute`` with ``params``.
    2. An ``UnpackTransform`` that maps a flat latent vector back to its
       original structure.

    :param dict params: Current parameters of model and autoguide.
    :return: the transform of posterior distribution
    :rtype: :class:`~numpyro.distributions.transforms.Transform`
    """
    learned = handlers.substitute(self._get_transform, params)()
    unpack = self._unpack_latent
    # ``_setup_prototype`` normally stores an ``UnpackTransform`` already.
    # Wrapping it a second time only adds a redundant layer of batch
    # flattening and reshaping on every call. Wrap only when a bare unpack
    # function was stored.
    if not isinstance(unpack, UnpackTransform):
        unpack = UnpackTransform(unpack)
    return ComposeTransform([learned, unpack])
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once to discover its latents and their initial values.

    Steps:

    1. Draw a setup RNG key from the sample site ``"_<prefix>_rng_key_setup"``.
    2. Run ``initialize_model`` inside ``handlers.block()``, so that the sites
       it touches are hidden from enclosing handlers.
    3. Flatten the first element of the returned init params with
       ``ravel_pytree``.

    Attributes set on ``self``:
        ``_postprocess_fn``, ``prototype_trace``: taken from
            ``initialize_model``'s return value.
        ``_init_latent``: the flattened initial values.
        ``_unpack_latent``: an ``UnpackTransform`` around the unravel function.
        ``latent_dim``: ``jnp.size`` of the flattened vector.

    :raises RuntimeError: if the flattened vector has zero elements.
    """
    setup_site = "_{}_rng_key_setup".format(self.prefix)
    rng_key = numpyro.sample(setup_site, dist.PRNGIdentity())
    with handlers.block():
        init_info, _, postprocess_fn, trace = initialize_model(
            rng_key,
            self.model,
            init_strategy=self.init_strategy,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs,
        )
        self._postprocess_fn = postprocess_fn
        self.prototype_trace = trace
    # NOTE(review): index 0 is presumably the unconstrained latent values in
    # initialize_model's param-info tuple -- confirm against the numpyro version.
    flat_init, unravel_fn = ravel_pytree(init_info[0])
    self._init_latent = flat_init
    # Wrapping in UnpackTransform matches Pyro, where unpack_latent can be
    # applied to a batch of samples.
    self._unpack_latent = UnpackTransform(unravel_fn)
    self.latent_dim = jnp.size(flat_init)
    if not self.latent_dim:
        guide_name = type(self).__name__
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                           .format(guide_name))
def get_transform(self, params):
    """
    Returns the transformation learned by the guide, composed with latent unpacking.

    The result is a ``ComposeTransform`` of two steps:

    1. ``self._get_transform(params)``.
    2. An ``UnpackTransform`` that maps a flat latent vector back to its
       original structure.

    :param params: Current parameters of model and autoguide, passed
        straight through to ``self._get_transform``.
    :return: the composed transform
    :rtype: :class:`~numpyro.distributions.transforms.Transform`
    """
    learned = self._get_transform(params)
    unpack = self._unpack_latent
    # ``_setup_prototype`` normally stores an ``UnpackTransform`` already.
    # Wrapping it again only adds a redundant batching/reshaping layer.
    # Wrap only when a bare unpack function was stored.
    if not isinstance(unpack, UnpackTransform):
        unpack = UnpackTransform(unpack)
    return ComposeTransform([learned, unpack])