Exemplo n.º 1
0
    def __call__(self, name, fn, obs):
        """
        Reparameterize an unobserved sample site through ``self.transform``.

        An auxiliary site named ``"{name}_{suffix}"`` is sampled from ``fn``
        pushed through ``biject_to(fn.support).inv`` followed by
        ``self.transform``. The value for the original site is then
        recovered by inverting that composed transform.

        :param name: Name of the site being reparameterized; used as the
            prefix of the auxiliary site name.
        :param fn: Distribution at the site. Its ``event_dim`` must be at
            least ``self.transform.event_dim``.
        :param obs: Observed value; must be ``None``.
        :returns: A pair ``(new_fn, value)`` where ``new_fn`` is a
            ``dist.Delta`` located at ``value`` with ``fn.event_dim`` event
            dimensions.
        :raises AssertionError: If ``obs`` is not ``None`` or if
            ``self.transform`` has more event dimensions than ``fn``.
        """
        assert obs is None, "TransformReparam does not support observe statements"
        assert fn.event_dim >= self.transform.event_dim, (
            "Cannot transform along batch dimension; "
            "try converting a batch dimension to an event dimension")

        # Compose the support bijection (wrapped by _with_cache) with the
        # user-supplied transform.
        to_base = _with_cache(biject_to(fn.support).inv)
        composed = ComposeTransform([to_base, self.transform])

        # Sample the auxiliary ("noise") site from the transformed distribution.
        aux_site = "{}_{}".format(name, self.suffix)
        noise = pyro.sample(aux_site, dist.TransformedDistribution(fn, composed))

        # Map the noise back differentiably; the transform cache is expected
        # to make this inversion free.
        value = composed.inv(noise)

        # Stand-in for a pyro.deterministic() site: a point mass at ``value``.
        return dist.Delta(value, event_dim=fn.event_dim), value
Exemplo n.º 2
0
    def __call__(self, name, fn, obs):
        """
        Reparameterize an unobserved sample site through ``self.transform``.

        An auxiliary site named ``"{name}_{suffix}"`` is sampled from ``fn``
        pushed through ``biject_to(fn.support).inv`` followed by
        ``self.transform``, and the original value is recovered by inverting
        that composed transform.

        When ``self.transform.event_dim`` exceeds ``fn.event_dim`` the
        difference (``shift``) is made up by reshaping ``fn`` so that its
        trailing ``shift`` batch dimensions become event dimensions; this
        path requires ``self.experimental_allow_batch``.

        :param name: Name of the site being reparameterized; used as the
            prefix of the auxiliary site name.
        :param fn: Distribution at the site.
        :param obs: Observed value; must be ``None``.
        :returns: A pair ``(new_fn, x)`` where ``new_fn`` is a ``dist.Delta``
            located at ``x`` with the original ``fn.event_dim`` event
            dimensions.
        :raises AssertionError: If ``obs`` is not ``None``.
        :raises ValueError: If the transform needs more event dimensions than
            ``fn`` has and ``self.experimental_allow_batch`` is falsy.
        """
        assert obs is None, "TransformReparam does not support observe statements"
        event_dim = fn.event_dim
        transform = self.transform
        with ExitStack() as stack:
            # Number of batch dims that must be treated as event dims.
            shift = max(0, transform.event_dim - event_dim)
            if shift:
                if not self.experimental_allow_batch:
                    # Adjacent literals are concatenated verbatim, so each
                    # fragment needs its own trailing space.
                    raise ValueError(
                        "Cannot transform along batch dimension; try either "
                        "converting a batch dimension to an event dimension, or "
                        "setting experimental_allow_batch=True.")

                # Reshape and mute plates using block_plate.
                from pyro.contrib.forecast.util import (
                    reshape_batch,
                    reshape_transform_batch,
                )
                old_shape = fn.batch_shape
                # Insert ``shift`` singleton dims in front of the trailing
                # ``shift`` batch dims, then move those trailing dims to the
                # event side with to_event(shift).
                new_shape = old_shape[:-shift] + (
                    1, ) * shift + old_shape[-shift:]
                fn = reshape_batch(fn, new_shape).to_event(shift)
                # NOTE(review): this reshaped transform is overwritten below
                # by a ComposeTransform built from ``self.transform`` --
                # confirm that discarding it is intended.
                transform = reshape_transform_batch(transform,
                                                    old_shape + fn.event_shape,
                                                    new_shape + fn.event_shape)
                for dim in range(-shift, 0):
                    stack.enter_context(block_plate(dim=dim, strict=False))

            # Draw noise from the base distribution.
            transform = ComposeTransform(
                [biject_to(fn.support).inv.with_cache(), self.transform])
            x_trans = pyro.sample("{}_{}".format(name, self.suffix),
                                  dist.TransformedDistribution(fn, transform))

        # Differentiably transform.
        x = transform.inv(x_trans)  # should be free due to transform cache
        if shift:
            # Drop the ``shift`` dims sitting just left of the trailing
            # ``shift + event_dim`` dims (the singleton dims added above).
            x = x.reshape(x.shape[:-2 * shift - event_dim] +
                          x.shape[-shift - event_dim:])

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(x, event_dim=event_dim)
        return new_fn, x
Exemplo n.º 3
0
    def apply(self, msg):
        """
        Reparameterize a sample site through ``self.transform``.

        An auxiliary site named ``"{name}_{suffix}"`` is sampled from ``fn``
        pushed through ``biject_to(fn.support).inv`` followed by
        ``self.transform``. If the incoming message already carries a value,
        that value is pushed forward through the composed transform and
        passed to the auxiliary site as ``obs``; otherwise the site value is
        recovered by inverting the composed transform on the sampled noise.

        When ``self.transform.event_dim`` exceeds ``fn.event_dim`` the
        difference (``shift``) is made up by reshaping ``fn`` (and the value,
        if any) so that the trailing ``shift`` batch dimensions become event
        dimensions; this path requires ``self.experimental_allow_batch``.

        :param dict msg: Site message; the keys ``"name"``, ``"fn"``,
            ``"value"`` and ``"is_observed"`` are read.
        :returns: A dict with ``"fn"`` (a ``dist.Delta`` located at the site
            value with the original ``fn.event_dim`` event dimensions),
            ``"value"`` (the site value) and ``"is_observed"`` (always
            ``True``).
        :raises ValueError: If the transform needs more event dimensions than
            ``fn`` has and ``self.experimental_allow_batch`` is falsy.
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        event_dim = fn.event_dim
        transform = self.transform
        with ExitStack() as stack:
            # Number of batch dims that must be treated as event dims.
            shift = max(0, transform.event_dim - event_dim)
            if shift:
                if not self.experimental_allow_batch:
                    # Adjacent literals are concatenated verbatim, so each
                    # fragment needs its own trailing space.
                    raise ValueError(
                        "Cannot transform along batch dimension; try either "
                        "converting a batch dimension to an event dimension, or "
                        "setting experimental_allow_batch=True.")

                # Reshape and mute plates using block_plate.
                from pyro.contrib.forecast.util import (
                    reshape_batch,
                    reshape_transform_batch,
                )

                old_shape = fn.batch_shape
                # Insert ``shift`` singleton dims in front of the trailing
                # ``shift`` batch dims, then move those trailing dims to the
                # event side with to_event(shift).
                new_shape = old_shape[:-shift] + (
                    1, ) * shift + old_shape[-shift:]
                fn = reshape_batch(fn, new_shape).to_event(shift)
                # NOTE(review): this reshaped transform is overwritten below
                # by a ComposeTransform built from ``self.transform`` --
                # confirm that discarding it is intended.
                transform = reshape_transform_batch(transform,
                                                    old_shape + fn.event_shape,
                                                    new_shape + fn.event_shape)
                if value is not None:
                    # Give the provided value the same extra singleton dims.
                    value = value.reshape(value.shape[:-shift - event_dim] +
                                          (1, ) * shift +
                                          value.shape[-shift - event_dim:])
                for dim in range(-shift, 0):
                    stack.enter_context(block_plate(dim=dim, strict=False))

            # Differentiably invert transform.
            transform = ComposeTransform(
                [biject_to(fn.support).inv.with_cache(), self.transform])
            value_trans = None
            if value is not None:
                value_trans = transform(value)

            # Draw noise from the base distribution.
            value_trans = pyro.sample(
                f"{name}_{self.suffix}",
                dist.TransformedDistribution(fn, transform),
                obs=value_trans,
                infer={"is_observed": is_observed},
            )

        # Differentiably transform. This should be free due to transform cache.
        if value is None:
            value = transform.inv(value_trans)
        if shift:
            # Drop the ``shift`` dims sitting just left of the trailing
            # ``shift + event_dim`` dims (the singleton dims added above).
            value = value.reshape(value.shape[:-2 * shift - event_dim] +
                                  value.shape[-shift - event_dim:])

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}