Ejemplo n.º 1
0
    def _setup_prototype(self, *args, **kwargs):
        """Record per-site transforms and the flattened initial latent.

        Runs the parent class's prototype setup, searches for valid initial
        parameters of ``self.model``, stores a ``biject_to`` transform for
        every unobserved sample site in ``self.prototype_trace``, and keeps
        the raveled initial parameters together with their unravel function.
        If ``self.base_dist`` is ``None`` it is replaced by an independent
        standard normal of size ``self.latent_size``.

        :param args: positional arguments forwarded to the parent setup and
            handed to the initial-parameter search as ``model_args``.
        :param kwargs: keyword arguments forwarded to the parent setup and
            handed to the initial-parameter search as ``model_kwargs``.
        :raises RuntimeError: if the flattened initial parameters are empty
            (no latent variables were found).
        """
        super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
        rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix),
                                 dist.PRNGIdentity())
        # NOTE(review): the search is wrapped in handlers.block, presumably so
        # the sites it visits stay hidden from enclosing handlers -- confirm.
        init_params, _ = handlers.block(find_valid_initial_params)(
            rng_key,
            self.model,
            init_strategy=self.init_strategy,
            model_args=args,
            model_kwargs=kwargs)

        self._inv_transforms = {}
        self._has_transformed_dist = False
        for name, site in self.prototype_trace.items():
            if site['type'] != 'sample' or site['is_observed']:
                continue
            if site['intermediates']:
                # Sites that recorded intermediates take their transform from
                # the base distribution's support rather than the site's own.
                self._inv_transforms[name] = biject_to(
                    site['fn'].base_dist.support)
                self._has_transformed_dist = True
            else:
                self._inv_transforms[name] = biject_to(site['fn'].support)

        self._init_latent, self._unpack_latent = ravel_pytree(init_params)
        self.latent_size = np.size(self._init_latent)
        # Fail before building a default base distribution for an empty latent.
        if self.latent_size == 0:
            raise RuntimeError(
                '{} found no latent variables; Use an empty guide instead'.
                format(type(self).__name__))
        if self.base_dist is None:
            self.base_dist = dist.Independent(
                dist.Normal(np.zeros(self.latent_size), 1.), 1)
Ejemplo n.º 2
0
def test_independent_shape(jax_dist, sp_dist, params):
    """Check the shapes reported by ``dist.Independent`` around ``jax_dist``.

    For every number of reinterpreted batch dimensions from 0 up to, but not
    including, the number of batch dimensions of ``jax_dist(*params)``, the
    wrapped distribution's ``batch_shape``, ``event_shape`` and the shape of
    ``log_prob`` on one drawn sample must match the expected split of the
    combined shape.

    :param jax_dist: callable that builds the base distribution.
    :param sp_dist: not used in this test body.
    :param params: positional arguments passed to ``jax_dist``.
    """
    base = jax_dist(*params)
    base_batch, base_event = base.batch_shape, base.event_shape
    full_shape = base_batch + base_event
    for ndims in range(len(base_batch)):
        wrapped = dist.Independent(base, reinterpreted_batch_ndims=ndims)
        draw = wrapped.sample(random.PRNGKey(0))
        # Index where the batch part ends once `ndims` trailing batch
        # dimensions are counted as event dimensions.
        split = len(full_shape) - len(base_event) - ndims
        expected_batch = full_shape[:split]
        expected_event = full_shape[split:]
        assert wrapped.batch_shape == expected_batch
        assert wrapped.event_shape == expected_event
        assert np.shape(wrapped.log_prob(draw)) == expected_batch