Exemple #1
0
def test_reshape_batch(Dist, batch_shape, duration, dim):
    """Check that ``reshape_batch`` can append a singleton batch dimension.

    A random ``Dist`` with event shape ``(duration, dim)`` is reshaped to
    ``batch_shape + (1,)``. The result must keep the same distribution type
    and the same event shape.
    """
    full_shape = batch_shape + (duration, dim)
    expected_event_shape = (duration, dim)

    original = random_dist(Dist, full_shape)
    # Promote dims so that exactly two rightmost dims are event dims.
    original = original.to_event(2 - original.event_dim)
    assert original.batch_shape == batch_shape
    assert original.event_shape == expected_event_shape

    target_batch_shape = batch_shape + (1,)
    reshaped = reshape_batch(original, target_batch_shape)
    assert type(reshaped) is type(original)
    assert reshaped.batch_shape == target_batch_shape
    assert reshaped.event_shape == expected_event_shape
Exemple #2
0
def test_reshape_transform_batch(transform, batch_shape, duration, dim):
    """Check ``reshape_batch`` on a ``LinearHMM`` carrying a parametrized transform.

    Every parameter of ``transform`` listed in ``UNIVARIATE_TRANSFORMS`` is
    filled with uniform random values of shape ``batch_shape + (duration, dim)``.
    After appending a singleton batch dimension, the reshaped distribution
    must keep its type and event shape. It must also produce samples whose
    shape matches its reported shape.
    """
    full_shape = batch_shape + (duration, dim)
    expected_event_shape = (duration, dim)

    param_names = UNIVARIATE_TRANSFORMS[transform]
    params = {name: torch.rand(full_shape) for name in param_names}
    t = transform(**params)

    original = random_dist(dist.LinearHMM, full_shape, transform=t)
    # Promote dims so that exactly two rightmost dims are event dims.
    original = original.to_event(2 - original.event_dim)
    assert original.batch_shape == batch_shape
    assert original.event_shape == expected_event_shape

    target_batch_shape = batch_shape + (1,)
    reshaped = reshape_batch(original, target_batch_shape)
    assert type(reshaped) is type(original)
    assert reshaped.batch_shape == target_batch_shape
    assert reshaped.event_shape == expected_event_shape

    # Samples must agree with the reported shape, which exercises whether the
    # transform's parameters were reshaped consistently with the batch shape.
    assert reshaped.rsample().shape == reshaped.shape()
Exemple #3
0
    def __call__(self, name, fn, obs):
        """
        Reparameterize sample site ``name`` by sampling in transformed space.

        A value is drawn at the auxiliary site ``"{name}_{self.suffix}"``.
        Its distribution is ``fn`` pushed through ``biject_to(fn.support).inv``
        followed by ``self.transform``. The value is mapped back through the
        inverse of that composition. A ``Delta`` at the result is then returned
        in place of ``fn``.

        If ``self.transform.event_dim`` exceeds ``fn.event_dim`` by ``shift``,
        two things happen. Singleton dims are inserted before the rightmost
        ``shift`` batch dims of ``fn``, and those batch dims are converted to
        event dims. The corresponding plates are also muted with ``block_plate``.
        This is only allowed when ``self.experimental_allow_batch`` is true.

        :param str name: Name of the sample site.
        :param fn: Distribution at the sample site.
        :param obs: Must be ``None``; observed sites are not supported.
        :returns: A pair ``(new_fn, x)``, where ``new_fn`` is a ``Delta`` at
            ``x`` with ``fn``'s original ``event_dim``.
        :raises ValueError: If the transform would act along a batch dimension
            and ``self.experimental_allow_batch`` is false.
        """
        assert obs is None, "TransformReparam does not support observe statements"
        event_dim = fn.event_dim
        transform = self.transform
        with ExitStack() as stack:
            # Number of batch dims the transform would consume beyond fn's event dims.
            shift = max(0, transform.event_dim - event_dim)
            if shift:
                if not self.experimental_allow_batch:
                    raise ValueError(
                        "Cannot transform along batch dimension; try either "
                        "converting a batch dimension to an event dimension, or "
                        "setting experimental_allow_batch=True.")

                # Reshape and mute plates using block_plate.
                from pyro.contrib.forecast.util import (
                    reshape_batch,
                    reshape_transform_batch,
                )
                old_shape = fn.batch_shape
                # Insert `shift` singleton dims before the rightmost `shift`
                # batch dims; those rightmost dims then become event dims.
                new_shape = old_shape[:-shift] + (
                    1, ) * shift + old_shape[-shift:]
                fn = reshape_batch(fn, new_shape).to_event(shift)
                # NOTE(review): this reshaped transform is overwritten below by a
                # composition built from the un-reshaped `self.transform`, so
                # the result is unused. Also, `fn.event_shape` here is read after
                # `fn` was reassigned. Both look unintended — TODO confirm
                # against reshape_transform_batch's contract.
                transform = reshape_transform_batch(transform,
                                                    old_shape + fn.event_shape,
                                                    new_shape + fn.event_shape)
                for dim in range(-shift, 0):
                    stack.enter_context(block_plate(dim=dim, strict=False))

            # Draw noise from the base distribution.
            transform = ComposeTransform(
                [biject_to(fn.support).inv.with_cache(), self.transform])
            x_trans = pyro.sample("{}_{}".format(name, self.suffix),
                                  dist.TransformedDistribution(fn, transform))

        # Differentiably transform.
        x = transform.inv(x_trans)  # should be free due to transform cache
        if shift:
            # Drop the singleton dims inserted above, restoring fn's original layout.
            x = x.reshape(x.shape[:-2 * shift - event_dim] +
                          x.shape[-shift - event_dim:])

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(x, event_dim=event_dim)
        return new_fn, x
Exemple #4
0
    def apply(self, msg):
        """
        Reparameterize the sample site described by ``msg``.

        The site is handled in transformed space through an auxiliary site
        ``f"{name}_{self.suffix}"``. Its distribution is ``msg["fn"]`` pushed
        through ``biject_to(fn.support).inv`` followed by ``self.transform``.
        If ``msg["value"]`` is given, it is transformed and passed as ``obs``.
        Otherwise the sampled value is mapped back through the inverse
        composition. A ``Delta`` at the original-space value replaces the
        site's distribution.

        If ``self.transform.event_dim`` exceeds ``fn.event_dim`` by ``shift``,
        two things happen. Singleton dims are inserted before the rightmost
        ``shift`` batch dims of ``fn`` (and of the value, if given), and those
        batch dims are converted to event dims. The corresponding plates are
        also muted with ``block_plate``. This is only allowed when
        ``self.experimental_allow_batch`` is true.

        :param dict msg: Site message; the ``"name"``, ``"fn"``, ``"value"``
            and ``"is_observed"`` keys are read.
        :returns: A dict with ``"fn"`` (a ``Delta`` at the value with ``fn``'s
            original ``event_dim``), ``"value"``, and ``"is_observed"`` set to
            ``True``.
        :raises ValueError: If the transform would act along a batch dimension
            and ``self.experimental_allow_batch`` is false.
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        event_dim = fn.event_dim
        transform = self.transform
        with ExitStack() as stack:
            # Number of batch dims the transform would consume beyond fn's event dims.
            shift = max(0, transform.event_dim - event_dim)
            if shift:
                if not self.experimental_allow_batch:
                    raise ValueError(
                        "Cannot transform along batch dimension; try either "
                        "converting a batch dimension to an event dimension, or "
                        "setting experimental_allow_batch=True.")

                # Reshape and mute plates using block_plate.
                from pyro.contrib.forecast.util import (
                    reshape_batch,
                    reshape_transform_batch,
                )

                old_shape = fn.batch_shape
                # Insert `shift` singleton dims before the rightmost `shift`
                # batch dims; those rightmost dims then become event dims.
                new_shape = old_shape[:-shift] + (
                    1, ) * shift + old_shape[-shift:]
                fn = reshape_batch(fn, new_shape).to_event(shift)
                # NOTE(review): this reshaped transform is overwritten below by a
                # composition built from the un-reshaped `self.transform`, so
                # the result is unused. Also, `fn.event_shape` here is read after
                # `fn` was reassigned. Both look unintended — TODO confirm
                # against reshape_transform_batch's contract.
                transform = reshape_transform_batch(transform,
                                                    old_shape + fn.event_shape,
                                                    new_shape + fn.event_shape)
                if value is not None:
                    # Give an observed value the same singleton-inserted layout as fn.
                    value = value.reshape(value.shape[:-shift - event_dim] +
                                          (1, ) * shift +
                                          value.shape[-shift - event_dim:])
                for dim in range(-shift, 0):
                    stack.enter_context(block_plate(dim=dim, strict=False))

            # Differentiably invert transform.
            transform = ComposeTransform(
                [biject_to(fn.support).inv.with_cache(), self.transform])
            value_trans = None
            if value is not None:
                value_trans = transform(value)

            # Draw noise from the base distribution.
            value_trans = pyro.sample(
                f"{name}_{self.suffix}",
                dist.TransformedDistribution(fn, transform),
                obs=value_trans,
                infer={"is_observed": is_observed},
            )

        # Differentiably transform. This should be free due to transform cache.
        if value is None:
            value = transform.inv(value_trans)
        if shift:
            # Drop the singleton dims inserted above, restoring fn's original layout.
            value = value.reshape(value.shape[:-2 * shift - event_dim] +
                                  value.shape[-shift - event_dim:])

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}