Ejemplo n.º 1
0
def _unconstrain_reparam(params, site):
    """
    Look up a stored value for ``site`` and map it through the bijection
    for the site's support.

    :param dict params: mapping from site name to a stored value.
    :param dict site: a site record; this function reads ``site['name']``,
        ``site['fn']`` (for ``support`` and ``event_shape``) and
        ``site['scale']``.
    :return: ``None`` when ``site['name']`` is not a key of ``params``; the
        stored value itself when the support is ``real`` or ``real_vector``;
        otherwise the stored value passed through ``biject_to(support)``.
        In that last case a ``numpyro.factor`` named ``_<name>_log_det`` is
        also registered with the transform's log-abs-det-Jacobian term.
    """
    name = site['name']
    if name not in params:
        # Nothing stored for this site: signal "no substitution".
        return None

    unconstrained = params[name]
    fn = site['fn']
    support = fn.support
    if support in (real, real_vector):
        # No bijection is applied for these supports.
        return unconstrained

    transform = biject_to(support)
    constrained = transform(unconstrained)

    # Number of trailing dims of the Jacobian term handed to sum_rightmost:
    # its rank, minus the value's rank, plus the distribution's event rank.
    jacobian = transform.log_abs_det_jacobian(unconstrained, constrained)
    trailing_dims = (
        jnp.ndim(jacobian) - jnp.ndim(constrained) + len(fn.event_shape))
    jacobian = sum_rightmost(jacobian, trailing_dims)

    scale = site['scale']
    if scale is not None:
        jacobian = scale * jacobian

    numpyro.factor('_{}_log_det'.format(name), jacobian)
    return constrained
Ejemplo n.º 2
0
    def __call__(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        An optional ``base_dist`` keyword is removed from ``kwargs`` and used
        as the first argument of the latent sampler; when it is missing or
        ``None``, ``_Normal(np.zeros(self.latent_size), 1.)`` is used instead.

        NOTE(review): ``base_dist`` is popped only after ``self.setup`` has
        been called, so on the first call ``setup`` still receives it inside
        ``kwargs`` -- confirm this is intended.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        draw_latent = self._sample_latent
        if self.prototype_trace is None:
            # No prototype trace yet: run setup to inspect the model
            # structure, and substitute the values it returns into the
            # latent sampler.
            setup_params = self.setup(*args, **kwargs)
            draw_latent = substitute(draw_latent, setup_params)

        base = kwargs.pop('base_dist', None)
        if base is None:
            base = _Normal(np.zeros(self.latent_size), 1.)
        latent = draw_latent(base, *args, **kwargs)

        # Unpack the continuous latent samples, site by site.
        samples = {}
        unpacked = self.unpack_latent(latent)
        for site_name, raw_value in unpacked.items():
            bijector = self._inv_transforms[site_name]
            proto_site = self.prototype_trace[site_name]

            constrained = bijector(raw_value)
            neg_log_det = -bijector.log_abs_det_jacobian(raw_value, constrained)

            # Sites with a truthy 'intermediates' entry take their event
            # rank from ``fn.base_dist``; all others from ``fn`` itself.
            shape_source = (proto_site['fn'].base_dist
                            if proto_site['intermediates']
                            else proto_site['fn'])
            event_ndim = len(shape_source.event_shape)

            trailing_dims = (
                np.ndim(neg_log_det) - np.ndim(constrained) + event_ndim)
            log_density = sum_rightmost(neg_log_det, trailing_dims)

            samples[site_name] = sample(
                site_name,
                dist.Delta(constrained,
                           log_density=log_density,
                           event_ndim=event_ndim))

        return samples
Ejemplo n.º 3
0
 def log_prob(self, value):
     """
     Evaluate ``self.base_dist.log_prob(value)`` and reduce the result with
     ``sum_rightmost`` using ``self.reinterpreted_batch_ndims``.

     :param value: the value forwarded unchanged to the base distribution.
     :return: the result of ``sum_rightmost`` (presumably the base
         log-probability summed over that many rightmost dimensions --
         confirm against the helper's definition).
     """
     return sum_rightmost(
         self.base_dist.log_prob(value),
         self.reinterpreted_batch_ndims,
     )
Ejemplo n.º 4
0
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     """
     Compute the base transform's log-abs-det-Jacobian and reduce it with
     ``sum_rightmost`` using ``self.reinterpreted_batch_ndims``.

     :param x: input forwarded to ``self.base_transform``.
     :param y: output forwarded to ``self.base_transform``.
     :param intermediates: forwarded by keyword to ``self.base_transform``.
     :return: the reduced Jacobian term.
     :raises ValueError: when the base transform's result has fewer
         dimensions than ``self.reinterpreted_batch_ndims``.
     """
     base_result = self.base_transform.log_abs_det_jacobian(
         x, y, intermediates=intermediates)
     if jnp.ndim(base_result) >= self.reinterpreted_batch_ndims:
         return sum_rightmost(base_result, self.reinterpreted_batch_ndims)

     # Too few dimensions to reduce over. The message reports the domain's
     # event_dim and the rank of ``x``, not of the Jacobian term.
     expected = self.domain.event_dim
     raise ValueError(f"Expected x.dim() >= {expected} but got {jnp.ndim(x)}")
Ejemplo n.º 5
0
 def log_prob(self, value):
     """
     Log-probability of ``value`` under a point mass at ``self.v``.

     The elementwise comparison ``value == self.v`` is passed through
     ``jnp.log`` (0 where equal, ``-inf`` where not), reduced with
     ``sum_rightmost`` over ``len(self.event_shape)`` dimensions, and then
     shifted by ``self.log_density``.

     :param value: the value to compare against ``self.v``.
     :return: the reduced log-indicator plus ``self.log_density``.
     """
     is_match = value == self.v
     pointwise = jnp.log(is_match)
     reduced = sum_rightmost(pointwise, len(self.event_shape))
     return reduced + self.log_density
Ejemplo n.º 6
0
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     """
     Return ``log|self.scale|`` broadcast to the shape of ``x`` and reduced
     with ``sum_rightmost`` using ``self.event_dim``.

     :param x: only its shape is used, as the broadcast target.
     :param y: unused here.
     :param intermediates: unused here.
     :return: the reduced log-absolute-scale term.
     """
     log_abs_scale = np.log(np.abs(self.scale))
     per_element = np.broadcast_to(log_abs_scale, np.shape(x))
     return sum_rightmost(per_element, self.event_dim)
Ejemplo n.º 7
0
 def log_prob(self, x):
     """
     Log-probability of ``x`` under a point mass at ``self.value``.

     When ``self._validate_args`` is truthy, ``self._validate_sample(x)`` is
     called first. The elementwise comparison ``x == self.value`` is then
     passed through ``np.log`` (0 where equal, ``-inf`` where not), reduced
     with ``sum_rightmost`` over ``len(self.event_shape)`` dimensions, and
     shifted by ``self.log_density``.

     :param x: the value to compare against ``self.value``.
     :return: the reduced log-indicator plus ``self.log_density``.
     """
     if self._validate_args:
         self._validate_sample(x)
     indicator_log = np.log(x == self.value)
     event_rank = len(self.event_shape)
     return sum_rightmost(indicator_log, event_rank) + self.log_density