Example 1: a model that adds a factor statement only when unmasked
def model():
    x = pyro.sample("x", dist.Normal(0, 1))
    pyro.sample("y", dist.Normal(x, 1), obs=torch.tensor(0.0))
    called.add("model-always")
    if poutine.get_mask() is not False:
        called.add("model-sometimes")
        pyro.factor("f", x + 1)
Example 2: transforming auxiliary values to constrained space with optional density bookkeeping
    def _transform_values(
        self,
        aux_values: Dict[str, torch.Tensor],
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Union[float, torch.Tensor]]]:
        # Learnably transform auxiliary values to user-facing values.
        values = {}
        log_densities = defaultdict(float)
        compute_density = am_i_wrapped() and poutine.get_mask() is not False
        for name, site in self._factors.items():
            if site["is_observed"]:
                continue
            loc = deep_getattr(self.locs, name)
            scale = deep_getattr(self.scales, name)
            unconstrained = aux_values[name] * scale + loc

            # Transform to constrained space.
            transform = biject_to(site["fn"].support)
            values[name] = transform(unconstrained)
            if compute_density:
                assert transform.codomain.event_dim == site["fn"].event_dim
                log_densities[name] = transform.inv.log_abs_det_jacobian(
                    values[name], unconstrained
                ) - scale.log().reshape(site["fn"].batch_shape + (-1,)).sum(-1)

        return values, log_densities
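The ``biject_to`` / ``log_abs_det_jacobian`` pair above is the standard change-of-variables correction. A standalone sketch (hypothetical values, using only ``torch.distributions``):

    import torch
    from torch.distributions import biject_to, constraints

    transform = biject_to(constraints.positive)  # here: ExpTransform
    unconstrained = torch.randn(3)
    value = transform(unconstrained)             # constrained sample
    # Density correction for the transformed variable:
    correction = transform.inv.log_abs_det_jacobian(value, unconstrained)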
Example 3: SplitReparam under the legacy reparameterizer interface
    def __call__(self, name, fn, obs):
        assert fn.event_dim >= self.event_dim
        assert obs is None, "SplitReparam does not support observe statements"

        # Draw independent parts.
        dim = fn.event_dim - self.event_dim
        left_shape = fn.event_shape[:dim]
        right_shape = fn.event_shape[1 + dim:]
        parts = []
        for i, size in enumerate(self.sections):
            event_shape = left_shape + (size,) + right_shape
            parts.append(
                pyro.sample(
                    "{}_split_{}".format(name, i),
                    dist.ImproperUniform(fn.support, fn.batch_shape,
                                         event_shape)))
        value = torch.cat(parts, dim=-self.event_dim)

        # Combine parts.
        if poutine.get_mask() is False:
            log_density = 0.0
        else:
            log_density = fn.log_prob(value)
        new_fn = dist.Delta(value,
                            event_dim=fn.event_dim,
                            log_density=log_density)
        return new_fn, value
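The Delta-with-``log_density`` trick above (used in most of the examples below) replaces a stochastic site with a point mass whose ``log_prob`` still carries the original density. A standalone sketch:

    import torch
    import pyro.distributions as dist

    value = torch.tensor([1.0, 2.0])
    log_density = torch.tensor(-3.5)         # any bookkeeping term
    d = dist.Delta(value, log_density=log_density, event_dim=1)
    assert d.log_prob(value) == log_density  # the point mass carries the density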
Example 4: StructuredReparam.apply
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]
        if name not in self.guide.prototype_trace.nodes:
            return {"fn": fn, "value": value, "is_observed": is_observed}
        if is_observed:
            raise NotImplementedError(
                f"At pyro.sample({repr(name)},...), "
                "StructuredReparam does not support observe statements")

        if name not in self.deltas:  # On first sample site.
            with ExitStack() as stack:
                for plate in self.guide.plates.values():
                    stack.enter_context(
                        block_plate(dim=plate.dim, strict=False))
                self.deltas = self.guide.get_deltas()
        new_fn = self.deltas.pop(name)
        value = new_fn.v

        if poutine.get_mask() is not False:
            log_density = new_fn.log_density + fn.log_prob(value)
            new_fn = dist.Delta(value, log_density, new_fn.event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}
Example 5: sampling auxiliary values by Gaussian tensor variable elimination (funsor)
    def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
        funsor = _import_funsor()

        # Convert torch to funsor.
        particle_plates = frozenset(get_plates())
        plate_to_dim = self._funsor_plate_to_dim.copy()
        plate_to_dim.update({f.name: f.dim for f in particle_plates})
        factors = {}
        for d, inputs in self._funsor_factor_inputs.items():
            batch_shape = torch.Size(
                p.size for p in sorted(self._plates[d], key=lambda p: p.dim)
            )
            white_vec = deep_getattr(self.white_vecs, d)
            prec_sqrt = deep_getattr(self.prec_sqrts, d)
            factors[d] = funsor.gaussian.Gaussian(
                white_vec=white_vec.reshape(batch_shape + white_vec.shape[-1:]),
                prec_sqrt=prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:]),
                inputs=inputs,
            )

        # Perform Gaussian tensor variable elimination.
        if temperature == 1:
            samples, log_prob = _try_possibly_intractable(
                funsor.recipes.forward_filter_backward_rsample,
                factors=factors,
                eliminate=self._funsor_eliminate,
                plates=frozenset(plate_to_dim),
                sample_inputs={f.name: funsor.Bint[f.size] for f in particle_plates},
            )

        else:
            samples, log_prob = _try_possibly_intractable(
                funsor.recipes.forward_filter_backward_precondition,
                factors=factors,
                eliminate=self._funsor_eliminate,
                plates=frozenset(plate_to_dim),
            )

            # Substitute noise.
            sample_shape = torch.Size(f.size for f in particle_plates)
            noise = torch.randn(sample_shape + log_prob.inputs["aux"].shape)
            noise.mul_(temperature)
            aux = funsor.Tensor(noise)[tuple(f.name for f in particle_plates)]
            with funsor.interpretations.memoize():
                samples = {k: v(aux=aux) for k, v in samples.items()}
                log_prob = log_prob(aux=aux)

        # Convert funsor to torch.
        if am_i_wrapped() and poutine.get_mask() is not False:
            log_prob = funsor.to_data(log_prob, name_to_dim=plate_to_dim)
            pyro.factor(f"_{self._pyro_name}_latent", log_prob, has_rsample=True)
        samples = {
            k: funsor.to_data(v, name_to_dim=plate_to_dim) for k, v in samples.items()
        }
        return samples
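The ``pyro.factor(..., has_rsample=True)`` call above injects the latent's log-density into the surrounding inference context; ``has_rsample=True`` marks the factor as reparameterized so it is legal inside guides. A minimal sketch of the primitive (name and value are hypothetical):

    import torch
    import pyro

    def toy_guide():
        log_prob = torch.tensor(-0.5, requires_grad=True)
        pyro.factor("_aux_latent", log_prob, has_rsample=True)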
Example 6: ConjugateReparam.apply
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        # Compute a guide distribution, either static or dependent.
        guide_dist = self.guide
        if not isinstance(guide_dist, dist.Distribution):
            args, kwargs = self.args_kwargs
            guide_dist = guide_dist(*args, **kwargs)
        assert isinstance(guide_dist, dist.Distribution)

        # Draw a sample from the updated distribution.
        fn, log_normalizer = fn.conjugate_update(guide_dist)
        assert isinstance(fn, dist.Distribution)
        if not fn.has_rsample:
            # Note supporting non-reparameterized sites would require more delicate
            # handling of traced sites than the crude _do_not_trace flag below.
            raise NotImplementedError(
                "ConjugateReparam inference supports only reparameterized "
                "distributions, but got {}".format(type(fn))
            )
        value = pyro.sample(
            f"{name}_updated",
            fn,
            obs=value,
            infer={
                "is_observed": is_observed,
                "is_auxiliary": True,
                "_do_not_trace": True,
            },
        )

        # Compute importance weight. Let p(z) be the original fn, q(z|x) be
        # the guide, and u(z) be the conjugate_updated distribution. Then
        #   normalizer = p(z) q(z|x) / u(z).
        # Since we've sampled from u(z) instead of p(z), we
        # need an importance weight
        #   p(z) / u(z) = normalizer / q(z|x)                          (Eqn 1)
        # Note that q(z|x) is often approximate; in the exact case
        #   q(z|x) = p(x|z) / integral p(x|z) dz
        # so this site and the downstream likelihood site will have combined density
        #   (p(z) / u(z)) p(x|z) = (normalizer / q(z|x)) p(x|z)
        #                        = normalizer integral p(x|z) dz
        # Hence in the exact case, downstream probability does not depend on the sampled z,
        # permitting this reparameterizer to be used in HMC.
        if poutine.get_mask() is False:
            log_density = 0.0
        else:
            log_density = log_normalizer - guide_dist.log_prob(value)  # By Eqn 1.

        # Return an importance-weighted point estimate.
        new_fn = dist.Delta(value, log_density=log_density, event_dim=fn.event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}
Example 7: an autoguide forward pass with per-site Normal distributions
    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        plates = self._create_plates(*args, **kwargs)
        result = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            transform = biject_to(site["fn"].support)

            with ExitStack() as stack:
                for frame in site["cond_indep_stack"]:
                    if frame.vectorized:
                        stack.enter_context(plates[frame.name])

                site_loc, site_scale = self._get_loc_and_scale(name)
                unconstrained_latent = pyro.sample(
                    name + "_unconstrained",
                    dist.Normal(
                        site_loc,
                        site_scale,
                    ).to_event(self._event_dims[name]),
                    infer={"is_auxiliary": True},
                )

                value = transform(unconstrained_latent)
                if poutine.get_mask() is False:
                    log_density = 0.0
                else:
                    log_density = transform.inv.log_abs_det_jacobian(
                        value,
                        unconstrained_latent,
                    )
                    log_density = sum_rightmost(
                        log_density,
                        log_density.dim() - value.dim() + site["fn"].event_dim,
                    )
                delta_dist = dist.Delta(
                    value,
                    log_density=log_density,
                    event_dim=site["fn"].event_dim,
                )

                result[name] = pyro.sample(name, delta_dist)

        return result
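A hedged end-to-end sketch of how this family of guides is typically used (standard ``pyro.infer.autoguide`` API; the model and data are hypothetical):

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.infer.autoguide import AutoNormal

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    guide = AutoNormal(model)
    svi = SVI(model, guide, pyro.optim.Adam({"lr": 1e-2}), Trace_ELBO())
    data = torch.randn(10)
    for _ in range(100):
        svi.step(data)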
Example 8: NeuTraReparam.apply
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]
        if name not in self.guide.prototype_trace.nodes:
            return {"fn": fn, "value": value, "is_observed": is_observed}
        if is_observed:
            raise NotImplementedError(
                f"At pyro.sample({repr(name)},...), "
                "NeuTraReparam does not support observe statements.")

        log_density = 0.0
        compute_density = poutine.get_mask() is not False
        if name not in self.x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            try:
                self.transform = self.guide.get_transform()
            except (NotImplementedError, TypeError) as e:
                raise ValueError(
                    "NeuTraReparam only supports guides that implement "
                    "`get_transform` method that does not depend on the "
                    "model's `*args, **kwargs`") from e

            with ExitStack() as stack:
                for plate in self.guide.plates.values():
                    stack.enter_context(
                        block_plate(dim=plate.dim, strict=False))
                z_unconstrained = pyro.sample(
                    f"{name}_shared_latent",
                    self.guide.get_base_dist().mask(False))

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            if compute_density:
                log_density = self.transform.log_abs_det_jacobian(
                    z_unconstrained, x_unconstrained)
            self.x_unconstrained = {
                site["name"]: (site, unconstrained_value)
                for site, unconstrained_value in self.guide._unpack_latent(
                    x_unconstrained)
            }

        # Extract a single site's value from the shared latent.
        site, unconstrained_value = self.x_unconstrained.pop(name)
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        if compute_density:
            logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
            logdet = sum_rightmost(logdet,
                                   logdet.dim() - value.dim() + fn.event_dim)
            log_density = log_density + fn.log_prob(value) + logdet
        new_fn = dist.Delta(value, log_density, event_dim=fn.event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}
Example 9: an autoguide forward pass that unpacks a shared latent
    def forward(self, *args, **kwargs):
        """
        An automatic guide with the same ``*args, **kwargs`` as the base ``model``.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.

        :return: A dict mapping sample site name to sampled value.
        :rtype: dict
        """
        # if we've never run the model before, do so now so we can inspect the model structure
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)

        latent = self.sample_latent(*args, **kwargs)
        plates = self._create_plates(*args, **kwargs)

        # unpack continuous latent samples
        result = {}
        for site, unconstrained_value in self._unpack_latent(latent):
            name = site["name"]
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained_value)
            if poutine.get_mask() is False:
                log_density = 0.0
            else:
                log_density = transform.inv.log_abs_det_jacobian(
                    value,
                    unconstrained_value,
                )
                log_density = sum_rightmost(
                    log_density,
                    log_density.dim() - value.dim() + site["fn"].event_dim,
                )
            delta_dist = dist.Delta(
                value,
                log_density=log_density,
                event_dim=site["fn"].event_dim,
            )

            with ExitStack() as stack:
                for frame in self._cond_indep_stacks[name]:
                    stack.enter_context(plates[frame.name])
                result[name] = pyro.sample(name, delta_dist)

        return result
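``sum_rightmost``, used above to reduce the Jacobian term over event dimensions, sums out a fixed number of rightmost dims. A quick sketch:

    import torch
    from pyro.distributions.util import sum_rightmost

    x = torch.randn(2, 3, 4)
    assert sum_rightmost(x, 0).shape == (2, 3, 4)  # no-op
    assert sum_rightmost(x, 2).shape == (2,)       # sums the two rightmost dims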
Example 10: NeuTraReparam under the legacy reparameterizer interface
    def __call__(self, name, fn, obs):
        if name not in self.guide.prototype_trace.nodes:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"
        log_density = 0.0
        compute_density = (poutine.get_mask() is not False)
        if not self.x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            try:
                self.transform = self.guide.get_transform()
            except (NotImplementedError, TypeError) as e:
                raise ValueError(
                    "NeuTraReparam only supports guides that implement "
                    "`get_transform` method that does not depend on the "
                    "model's `*args, **kwargs`") from e

            z_unconstrained = pyro.sample(
                "{}_shared_latent".format(name),
                self.guide.get_base_dist().mask(False))

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            if compute_density:
                log_density = self.transform.log_abs_det_jacobian(
                    z_unconstrained, x_unconstrained)
            self.x_unconstrained = list(
                reversed(list(self.guide._unpack_latent(x_unconstrained))))

        # Extract a single site's value from the shared latent.
        site, unconstrained_value = self.x_unconstrained.pop()
        assert name == site["name"], "model structure changed"
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        if compute_density:
            logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
            logdet = sum_rightmost(logdet,
                                   logdet.dim() - value.dim() + fn.event_dim)
            log_density = log_density + fn.log_prob(value) + logdet
        new_fn = dist.Delta(value, log_density, event_dim=fn.event_dim)
        return new_fn, value
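The ``.mask(False)`` call on the shared latent's base distribution lets the site be sampled without contributing to the log-joint. A standalone sketch:

    import pyro.distributions as dist

    d = dist.Normal(0.0, 1.0).mask(False)
    x = d.sample()               # sampling still works
    assert d.log_prob(x) == 0.0  # but the density is masked to zero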
Example 11: SplitReparam.apply
    def apply(self, msg):
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]
        assert fn.event_dim >= self.event_dim

        # Split value into parts.
        value_split = [None] * len(self.sections)
        if value is not None:
            value_split[:] = value.split(self.sections, -self.event_dim)

        # Draw independent parts.
        dim = fn.event_dim - self.event_dim
        left_shape = fn.event_shape[:dim]
        right_shape = fn.event_shape[1 + dim:]
        for i, size in enumerate(self.sections):
            event_shape = left_shape + (size,) + right_shape
            value_split[i] = pyro.sample(
                f"{name}_split_{i}",
                dist.ImproperUniform(fn.support, fn.batch_shape, event_shape),
                obs=value_split[i],
                infer={"is_observed": is_observed},
            )

        # Combine parts into value.
        if value is None:
            value = torch.cat(value_split, dim=-self.event_dim)

        if poutine.get_mask() is False:
            log_density = 0.0
        else:
            log_density = fn.log_prob(value)
        new_fn = dist.Delta(value,
                            event_dim=fn.event_dim,
                            log_density=log_density)
        return {"fn": new_fn, "value": value, "is_observed": True}
Example 12: a guide that adds a factor statement only when unmasked
def guide():
    x = pyro.sample("x", dist.Normal(0, 1))
    called.add("guide-always")
    if poutine.get_mask() is not False:
        called.add("guide-sometimes")
        pyro.factor("g", 2 - x)
Example 13: get_deltas with blockwise delta/normal/mvn conditionals
    def get_deltas(self, save_params=None):
        deltas = {}
        aux_values = {}
        compute_density = poutine.get_mask() is not False
        for name, site in self._sorted_sites:
            if save_params is not None and name not in save_params:
                continue

            # Sample zero-mean blockwise independent Delta/Normal/MVN.
            log_density = 0.0
            loc = deep_getattr(self.locs, name)
            zero = torch.zeros_like(loc)
            conditional = self.conditionals[name]
            if callable(conditional):
                aux_value = deep_getattr(self.conds, name)()
            elif conditional == "delta":
                aux_value = zero
            elif conditional == "normal":
                aux_value = pyro.sample(
                    name + "_aux",
                    dist.Normal(zero, 1).to_event(1),
                    infer={"is_auxiliary": True},
                )
                scale = deep_getattr(self.scales, name)
                aux_value = aux_value * scale
                if compute_density:
                    log_density = (-scale.log()).expand_as(aux_value)
            elif conditional == "mvn":
                # This overparametrizes by learning (scale,scale_tril),
                # enabling faster learning of the more-global scale parameter.
                aux_value = pyro.sample(
                    name + "_aux",
                    dist.Normal(zero, 1).to_event(1),
                    infer={"is_auxiliary": True},
                )
                scale = deep_getattr(self.scales, name)
                scale_tril = deep_getattr(self.scale_trils, name)
                aux_value = aux_value @ scale_tril.T * scale
                if compute_density:
                    log_density = (
                        -scale_tril.diagonal(dim1=-2, dim2=-1).log() -
                        scale.log()).expand_as(aux_value)
            else:
                raise ValueError(
                    f"Unsupported conditional type: {conditional}")

            # Accumulate upstream dependencies.
            # Note: by accumulating upstream dependencies before updating the
            # aux_values dict, we encode a block-sparse structure of the
            # precision matrix; if we had instead accumulated after updating
            # aux_values, we would encode a block-sparse structure of the
            # covariance matrix.
            # Note: these shear transforms have no effect on the Jacobian
            # determinant, and can therefore be excluded from the log_density
            # computation below, even for nonlinear dep().
            deps = deep_getattr(self.deps, name)
            for upstream in self.dependencies.get(name, {}):
                dep = deep_getattr(deps, upstream)
                aux_value = aux_value + dep(aux_values[upstream])
            aux_values[name] = aux_value

            # Shift by loc and reshape.
            batch_shape = torch.broadcast_shapes(aux_value.shape[:-1],
                                                 self._batch_shapes[name])
            unconstrained = (
                aux_value +
                loc).reshape(batch_shape +
                             self._unconstrained_event_shapes[name])
            if not is_identically_zero(log_density):
                log_density = log_density.reshape(batch_shape + (-1, )).sum(-1)

            # Transform to constrained space.
            transform = biject_to(site["fn"].support)
            value = transform(unconstrained)
            if compute_density and conditional != "delta":
                assert transform.codomain.event_dim == site["fn"].event_dim
                log_density = log_density + transform.inv.log_abs_det_jacobian(
                    value, unconstrained)

            # Create a reparametrized Delta distribution.
            deltas[name] = dist.Delta(value, log_density, site["fn"].event_dim)

        return deltas
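This method matches the structure of ``pyro.infer.autoguide.AutoStructured``. A hedged configuration sketch, assuming that API (the site names "loc" and "scale" are hypothetical):

    from pyro.infer.autoguide import AutoStructured

    guide = AutoStructured(
        model,
        conditionals={"loc": "normal", "scale": "mvn"},
        dependencies={"scale": {"loc": "linear"}},
    )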