def _setup_prototype(self, *args, **kwargs):
    """Trace the model once and derive the flat latent-space layout of the guide.

    Runs the parent class's prototype setup, finds valid initial parameters
    for ``self.model`` and flattens them into a single latent vector. For
    every unobserved sample site in ``self.prototype_trace`` it records the
    bijection returned by ``biject_to`` for that site's support.

    Attributes set on ``self``:
        _inv_transforms: dict mapping site name -> transform from ``biject_to``.
        _has_transformed_dist: True if any unobserved sample site recorded
            intermediates.
        _init_latent: flattened initial parameters (from ``ravel_pytree``).
        _unpack_latent: the unflatten function returned alongside
            ``_init_latent`` by ``ravel_pytree``.
        latent_size: number of elements in ``_init_latent``.
        base_dist: if still ``None``, a standard Normal over the latent
            vector whose single batch dim is reinterpreted as an event dim.

    :param args: positional arguments passed to the model.
    :param kwargs: keyword arguments passed to the model.
    :raises RuntimeError: if the model has no latent variables.
    """
    super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
    rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix), dist.PRNGIdentity())
    # ``handlers.block`` keeps the sites visited during the initial-parameter
    # search from being seen by any enclosing handlers.
    init_params, _ = handlers.block(find_valid_initial_params)(
        rng_key, self.model,
        init_strategy=self.init_strategy,
        model_args=args,
        model_kwargs=kwargs)

    self._inv_transforms = {}
    self._has_transformed_dist = False
    for name, site in self.prototype_trace.items():
        if site['type'] != 'sample' or site['is_observed']:
            continue
        if site['intermediates']:
            # Presumably a transformed distribution: the bijection targets
            # its base distribution's support, not the site's own support.
            support = site['fn'].base_dist.support
            self._has_transformed_dist = True
        else:
            support = site['fn'].support
        self._inv_transforms[name] = biject_to(support)

    self._init_latent, self._unpack_latent = ravel_pytree(init_params)
    self.latent_size = np.size(self._init_latent)
    # Fail fast, before building a zero-dimensional default base distribution.
    if self.latent_size == 0:
        raise RuntimeError(
            '{} found no latent variables; Use an empty guide instead'.
            format(type(self).__name__))
    if self.base_dist is None:
        self.base_dist = dist.Independent(
            dist.Normal(np.zeros(self.latent_size), 1.), 1)
def test_independent_shape(jax_dist, sp_dist, params):
    """Check ``dist.Independent`` shapes for every valid reinterpretation count.

    For each ``i`` from 0 up to and including ``len(batch_shape)``, wrapping
    ``d = jax_dist(*params)`` as ``dist.Independent(d, reinterpreted_batch_ndims=i)``
    must move the rightmost ``i`` batch dims into the event shape, and the
    ``log_prob`` of a sample must have the remaining batch shape.

    ``sp_dist`` is not used here; presumably it is supplied by a shared
    pytest parametrization.
    """
    d = jax_dist(*params)
    batch_shape, event_shape = d.batch_shape, d.event_shape
    shape = batch_shape + event_shape
    # Include i == len(batch_shape) (all batch dims reinterpreted); this also
    # makes the test exercise distributions whose batch_shape is empty.
    for i in range(len(batch_shape) + 1):
        indep = dist.Independent(d, reinterpreted_batch_ndims=i)
        sample = indep.sample(random.PRNGKey(0))
        # Position in ``shape`` where the (enlarged) event shape begins.
        event_boundary = len(batch_shape) - i
        assert indep.batch_shape == shape[:event_boundary]
        assert indep.event_shape == shape[event_boundary:]
        assert np.shape(indep.log_prob(sample)) == shape[:event_boundary]