예제 #1
0
파일: structured.py 프로젝트: pyro-ppl/pyro
    def apply(self, msg):
        """
        Replace a latent sample site with a point mass taken from the guide.

        Sites absent from ``self.guide.prototype_trace.nodes`` are returned
        unchanged. The first time a guide site is seen (i.e. it is not yet a
        key of ``self.deltas``), ``self.guide.get_deltas()`` is called with
        each of the guide's plates blocked via ``block_plate`` and the result
        is cached in ``self.deltas``; every site then pops its own entry.

        When ``poutine.get_mask()`` is not ``False``, the model's
        ``fn.log_prob`` at the delta's value is added to the delta's
        ``log_density``.

        :param dict msg: A sample-site message; the ``"name"``, ``"fn"``,
            ``"value"`` and ``"is_observed"`` entries are read.
        :returns: A dict with new ``"fn"``, ``"value"`` and ``"is_observed"``
            entries for the site.
        :raises NotImplementedError: If the site is observed.
        """
        site_name, prior, value, is_observed = (
            msg["name"], msg["fn"], msg["value"], msg["is_observed"])

        # Sites the guide knows nothing about pass straight through.
        if site_name not in self.guide.prototype_trace.nodes:
            return {"fn": prior, "value": value, "is_observed": is_observed}
        if is_observed:
            raise NotImplementedError(
                f"At pyro.sample({repr(site_name)},...), "
                "StructuredReparam does not support observe statements")

        if site_name not in self.deltas:
            # First reparametrized site: draw all guide deltas in one call,
            # with every guide plate blocked while doing so.
            with ExitStack() as stack:
                for guide_plate in self.guide.plates.values():
                    stack.enter_context(
                        block_plate(dim=guide_plate.dim, strict=False))
                self.deltas = self.guide.get_deltas()

        point_mass = self.deltas.pop(site_name)
        point_value = point_mass.v

        if poutine.get_mask() is False:
            # Density computation is masked out; keep the guide's delta as is.
            return {"fn": point_mass, "value": point_value, "is_observed": True}

        # Fold the model's log-probability into the delta's log-density.
        combined = point_mass.log_density + prior.log_prob(point_value)
        weighted = dist.Delta(point_value, combined, point_mass.event_dim)
        return {"fn": weighted, "value": point_value, "is_observed": True}
예제 #2
0
파일: neutra.py 프로젝트: pyro-ppl/pyro
    def apply(self, msg):
        """
        Reparametrize a latent site through the guide's shared transform.

        Sites absent from ``self.guide.prototype_trace.nodes`` are returned
        unchanged. On the first reparametrized site (one not yet a key of
        ``self.x_unconstrained``):

        - ``self.guide.get_transform()`` is stored as ``self.transform``;
        - a latent ``"{name}_shared_latent"`` is sampled from the guide's
          base distribution masked with ``mask(False)``, with the guide's
          plates blocked;
        - the latent is pushed through ``self.transform`` and split into
          per-site pieces via ``self.guide._unpack_latent``, cached in
          ``self.x_unconstrained``.

        Every site then pops its own unconstrained piece, maps it into
        ``fn.support`` with ``biject_to``, and is replaced by a ``Delta`` at
        the result. When ``poutine.get_mask()`` is not ``False``, the Delta's
        log-density is the shared transform's log-Jacobian (first site only),
        plus ``fn.log_prob`` of the value, plus the support bijection's
        log-Jacobian (summed down to ``fn.event_dim`` rightmost dims).

        :param dict msg: A sample-site message; the ``"name"``, ``"fn"``,
            ``"value"`` and ``"is_observed"`` entries are read.
        :returns: A dict with new ``"fn"``, ``"value"`` and ``"is_observed"``
            entries for the site.
        :raises NotImplementedError: If the site is observed.
        :raises ValueError: If ``self.guide.get_transform()`` raises
            ``NotImplementedError`` or ``TypeError``.
        """
        site_name, prior, value, is_observed = (
            msg["name"], msg["fn"], msg["value"], msg["is_observed"])

        # Sites the guide knows nothing about pass straight through.
        if site_name not in self.guide.prototype_trace.nodes:
            return {"fn": prior, "value": value, "is_observed": is_observed}
        if is_observed:
            raise NotImplementedError(
                f"At pyro.sample({repr(site_name)},...), "
                "NeuTraReparam does not support observe statements.")

        log_density = 0.0
        want_density = poutine.get_mask() is not False

        if site_name not in self.x_unconstrained:
            # First reparametrized site: fetch the guide's transform...
            try:
                self.transform = self.guide.get_transform()
            except (NotImplementedError, TypeError) as err:
                raise ValueError(
                    "NeuTraReparam only supports guides that implement "
                    "`get_transform` method that does not depend on the "
                    "model's `*args, **kwargs`") from err

            # ...sample one shared latent with all guide plates blocked...
            with ExitStack() as stack:
                for guide_plate in self.guide.plates.values():
                    stack.enter_context(
                        block_plate(dim=guide_plate.dim, strict=False))
                z = pyro.sample(f"{site_name}_shared_latent",
                                self.guide.get_base_dist().mask(False))

            # ...and push it differentiably through the transform.
            x = self.transform(z)
            if want_density:
                log_density = self.transform.log_abs_det_jacobian(z, x)

            per_site = {}
            for site, x_site in self.guide._unpack_latent(x):
                per_site[site["name"]] = (site, x_site)
            self.x_unconstrained = per_site

        # Take this site's unconstrained piece and map it into fn's support.
        _, x_site = self.x_unconstrained.pop(site_name)
        to_support = biject_to(prior.support)
        constrained = to_support(x_site)
        if want_density:
            jac = to_support.log_abs_det_jacobian(x_site, constrained)
            jac = sum_rightmost(
                jac, jac.dim() - constrained.dim() + prior.event_dim)
            log_density = log_density + prior.log_prob(constrained) + jac

        return {
            "fn": dist.Delta(constrained, log_density,
                             event_dim=prior.event_dim),
            "value": constrained,
            "is_observed": True,
        }
예제 #3
0
    def __call__(self, name, fn, obs):
        """
        Reparametrize a latent sample site by sampling in transformed space.

        An auxiliary site ``"{name}_{suffix}"`` is sampled from
        ``TransformedDistribution(fn, T)``, where ``T`` composes
        ``biject_to(fn.support).inv`` (with caching) followed by
        ``self.transform``. The auxiliary sample is mapped back through
        ``T.inv`` to give the site's value.

        If ``self.transform.event_dim`` exceeds ``fn.event_dim`` by ``shift``
        dims, those rightmost batch dims are moved into the event shape
        (this requires ``self.experimental_allow_batch``). Plates at dims
        ``-shift .. -1`` are blocked, and the value is reshaped back afterwards.

        :param str name: Name of the sample site.
        :param fn: The site's distribution; ``event_dim``, ``batch_shape``
            and ``support`` are read.
        :param obs: Must be ``None``; observed sites are not supported.
        :returns: A ``(Delta, value)`` pair; the ``Delta`` has the original
            ``event_dim`` and stands in for a deterministic site.
        :raises ValueError: If batch dims would need transforming and
            ``self.experimental_allow_batch`` is false.
        """
        assert obs is None, "TransformReparam does not support observe statements"
        event_dim = fn.event_dim
        transform = self.transform
        with ExitStack() as stack:
            # Number of rightmost batch dims that self.transform also acts on.
            shift = max(0, transform.event_dim - event_dim)
            if shift:
                if not self.experimental_allow_batch:
                    # Trailing space after "either" keeps the joined message
                    # readable ("try either converting ...").
                    raise ValueError(
                        "Cannot transform along batch dimension; try either "
                        "converting a batch dimension to an event dimension, or "
                        "setting experimental_allow_batch=True.")

                # Reshape and mute plates using block_plate.
                from pyro.contrib.forecast.util import (
                    reshape_batch,
                    reshape_transform_batch,
                )
                # Insert `shift` singleton dims before the rightmost `shift`
                # batch dims, then reinterpret those rightmost dims as event
                # dims via to_event.
                old_shape = fn.batch_shape
                new_shape = old_shape[:-shift] + (1,) * shift + old_shape[-shift:]
                fn = reshape_batch(fn, new_shape).to_event(shift)
                # NOTE(review): this reshaped transform is overwritten by the
                # ComposeTransform below, which uses self.transform instead;
                # confirm whether the reshaped transform was meant to be used.
                transform = reshape_transform_batch(transform,
                                                    old_shape + fn.event_shape,
                                                    new_shape + fn.event_shape)
                for dim in range(-shift, 0):
                    stack.enter_context(block_plate(dim=dim, strict=False))

            # Draw noise from the base distribution.
            transform = ComposeTransform(
                [biject_to(fn.support).inv.with_cache(), self.transform])
            x_trans = pyro.sample(f"{name}_{self.suffix}",
                                  dist.TransformedDistribution(fn, transform))

        # Differentiably transform.
        x = transform.inv(x_trans)  # should be free due to transform cache
        if shift:
            # Drop the `shift` singleton dims inserted above.
            x = x.reshape(x.shape[:-2 * shift - event_dim] +
                          x.shape[-shift - event_dim:])

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(x, event_dim=event_dim)
        return new_fn, x
예제 #4
0
    def apply(self, msg):
        """
        Reparametrize a sample site by sampling or observing it in
        transformed space.

        An auxiliary site ``"{name}_{suffix}"`` is created with distribution
        ``TransformedDistribution(fn, T)``, where ``T`` composes
        ``biject_to(fn.support).inv`` (with caching) followed by
        ``self.transform``. If ``msg["value"]`` is not ``None`` it is mapped
        through ``T`` and passed as the auxiliary site's ``obs``. Otherwise
        the auxiliary sample is mapped back through ``T.inv``. The original
        site becomes a ``Delta`` at the resulting value.

        If ``self.transform.event_dim`` exceeds ``fn.event_dim`` by ``shift``
        dims, those rightmost batch dims are moved into the event shape
        (this requires ``self.experimental_allow_batch``). Plates at dims
        ``-shift .. -1`` are blocked, and the value is reshaped back afterwards.

        :param dict msg: A sample-site message; the ``"name"``, ``"fn"``,
            ``"value"`` and ``"is_observed"`` entries are read.
        :returns: A dict with new ``"fn"``, ``"value"`` and ``"is_observed"``
            (always ``True``) entries.
        :raises ValueError: If batch dims would need transforming and
            ``self.experimental_allow_batch`` is false.
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        event_dim = fn.event_dim
        transform = self.transform
        with ExitStack() as stack:
            # Number of rightmost batch dims that self.transform also acts on.
            shift = max(0, transform.event_dim - event_dim)
            if shift:
                if not self.experimental_allow_batch:
                    # Trailing space after "either" keeps the joined message
                    # readable ("try either converting ...").
                    raise ValueError(
                        "Cannot transform along batch dimension; try either "
                        "converting a batch dimension to an event dimension, or "
                        "setting experimental_allow_batch=True.")

                # Reshape and mute plates using block_plate.
                from pyro.contrib.forecast.util import (
                    reshape_batch,
                    reshape_transform_batch,
                )

                # Insert `shift` singleton dims before the rightmost `shift`
                # batch dims, then reinterpret those rightmost dims as event
                # dims via to_event.
                old_shape = fn.batch_shape
                new_shape = old_shape[:-shift] + (1,) * shift + old_shape[-shift:]
                fn = reshape_batch(fn, new_shape).to_event(shift)
                # NOTE(review): this reshaped transform is overwritten by the
                # ComposeTransform below, which uses self.transform instead;
                # confirm whether the reshaped transform was meant to be used.
                transform = reshape_transform_batch(transform,
                                                    old_shape + fn.event_shape,
                                                    new_shape + fn.event_shape)
                if value is not None:
                    # Insert matching singleton dims into the observed value.
                    value = value.reshape(value.shape[:-shift - event_dim] +
                                          (1,) * shift +
                                          value.shape[-shift - event_dim:])
                for dim in range(-shift, 0):
                    stack.enter_context(block_plate(dim=dim, strict=False))

            # Differentiably invert transform.
            transform = ComposeTransform(
                [biject_to(fn.support).inv.with_cache(), self.transform])
            value_trans = None if value is None else transform(value)

            # Draw noise from the base distribution (observed when
            # value_trans is not None).
            value_trans = pyro.sample(
                f"{name}_{self.suffix}",
                dist.TransformedDistribution(fn, transform),
                obs=value_trans,
                infer={"is_observed": is_observed},
            )

        # Differentiably transform. This should be free due to transform cache.
        if value is None:
            value = transform.inv(value_trans)
        if shift:
            # Drop the `shift` singleton dims inserted above.
            value = value.reshape(value.shape[:-2 * shift - event_dim] +
                                  value.shape[-shift - event_dim:])

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}