def __init__(self, base_distribution, transforms, validate_args=None):
        """Build a distribution by pushing ``base_distribution`` through ``transforms``.

        :param base_distribution: the distribution whose samples are transformed.
        :param transforms: a single ``Transform`` or a list of ``Transform``
            objects; they are combined with ``ComposeTransform``.
        :param validate_args: forwarded to the parent ``__init__``.
        :raises ValueError: if ``transforms`` is neither a ``Transform`` nor a
            list of them, or if the base distribution has fewer total dims
            than the composed transform's domain ``event_dim``.
        """
        if isinstance(transforms, Transform):
            self.transforms = [transforms]
        elif not isinstance(transforms, list):
            raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
        elif any(not isinstance(part, Transform) for part in transforms):
            raise ValueError("transforms must be a Transform or a list of Transforms")
        else:
            # The caller's list object is stored as-is (not copied).
            self.transforms = transforms

        # Full shape of a single base draw, and how many of its dims are event dims.
        base_shape = base_distribution.batch_shape + base_distribution.event_shape
        n_base_event = len(base_distribution.event_shape)
        chain = ComposeTransform(self.transforms)
        n_domain_event = chain.domain.event_dim
        if n_domain_event > len(base_shape):
            raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
                             .format(n_domain_event, base_shape))

        # Map the base shape forward, then back. If the round trip yields a
        # different base shape (for instance because transform parameters
        # broadcast in extra batch dims), expand the base's batch dims to match.
        out_shape = chain.forward_shape(base_shape)
        needed_base_shape = chain.inverse_shape(out_shape)
        if needed_base_shape != base_shape:
            n_needed_batch = len(needed_base_shape) - n_base_event
            base_distribution = base_distribution.expand(needed_base_shape[:n_needed_batch])

        # Reinterpret base batch dims as event dims when the transform's
        # domain needs more event dims than the base provides.
        n_reinterpreted = n_domain_event - n_base_event
        if n_reinterpreted > 0:
            base_distribution = Independent(base_distribution, n_reinterpreted)
        self.base_dist = base_distribution

        # Split the transformed shape into batch and event parts.
        n_event = chain.codomain.event_dim + max(n_base_event - n_domain_event, 0)
        assert len(out_shape) >= n_event
        n_batch = len(out_shape) - n_event
        super(TransformedDistribution, self).__init__(out_shape[:n_batch], out_shape[n_batch:],
                                                      validate_args=validate_args)
Example #2
0
def get_transforms(cache_size):
    """Return the transforms under test followed by the inverse of each one.

    ``cache_size`` is forwarded to every transform that is given it below
    (the ``ComposeTransform`` wrappers, ``ReshapeTransform``,
    ``IndependentTransform`` and ``CumulativeDistributionTransform`` do not
    receive it). Random parameters are drawn in list order, so the
    construction order matters under a fixed seed.
    """
    cached = {"cache_size": cache_size}

    def affine_randn(*shape):
        # loc is drawn before scale, matching left-to-right argument evaluation.
        return AffineTransform(torch.randn(*shape), torch.randn(*shape), **cached)

    forward = [
        AbsTransform(**cached),
        ExpTransform(**cached),
        PowerTransform(exponent=2, **cached),
        PowerTransform(exponent=torch.tensor(5.).normal_(), **cached),
        PowerTransform(exponent=torch.tensor(5.).normal_(), **cached),
        SigmoidTransform(**cached),
        TanhTransform(**cached),
        AffineTransform(0, 1, **cached),
        AffineTransform(1, -2, **cached),
        affine_randn(5),
        affine_randn(4, 5),
        SoftmaxTransform(**cached),
        SoftplusTransform(**cached),
        StickBreakingTransform(**cached),
        LowerCholeskyTransform(**cached),
        CorrCholeskyTransform(**cached),
        ComposeTransform([affine_randn(4, 5)]),
        ComposeTransform([affine_randn(4, 5), ExpTransform(**cached)]),
        ComposeTransform([
            AffineTransform(0, 1, **cached),
            affine_randn(4, 5),
            AffineTransform(1, -2, **cached),
            affine_randn(4, 5),
        ]),
        ReshapeTransform((4, 5), (2, 5, 2)),
        IndependentTransform(affine_randn(5), 1),
        CumulativeDistributionTransform(Normal(0, 1)),
    ]
    return forward + [t.inv for t in forward]
Example #3
0
def test_compose_transform_shapes():
    """Check ``event_dim`` of single transforms and of pairwise compositions.

    In each composition the expected ``event_dim`` equals the largest
    ``event_dim`` among its parts.
    """
    exp_t = ExpTransform()
    softmax_t = SoftmaxTransform()
    chol_t = LowerCholeskyTransform()

    for single, expected in ((exp_t, 0), (softmax_t, 1), (chol_t, 2)):
        assert single.event_dim == expected

    pairs = (
        ((exp_t, softmax_t), 1),
        ((exp_t, chol_t), 2),
        ((softmax_t, chol_t), 2),
    )
    for parts, expected in pairs:
        assert ComposeTransform(list(parts)).event_dim == expected
Example #4
0
def get_transforms(cache_size):
    """Return the transforms under test followed by the inverse of each one.

    ``cache_size`` is forwarded to every transform constructed below except the
    ``ComposeTransform`` wrappers themselves. Random parameters are drawn in
    list order, so the construction order matters under a fixed seed.
    """
    cached = {"cache_size": cache_size}

    def affine_randn(*shape):
        # loc is drawn before scale, matching left-to-right argument evaluation.
        return AffineTransform(torch.randn(*shape), torch.randn(*shape), **cached)

    forward = [
        AbsTransform(**cached),
        ExpTransform(**cached),
        PowerTransform(exponent=2, **cached),
        PowerTransform(exponent=torch.tensor(5.).normal_(), **cached),
        SigmoidTransform(**cached),
        TanhTransform(**cached),
        AffineTransform(0, 1, **cached),
        AffineTransform(1, -2, **cached),
        affine_randn(5),
        affine_randn(4, 5),
        SoftmaxTransform(**cached),
        StickBreakingTransform(**cached),
        LowerCholeskyTransform(**cached),
        CorrCholeskyTransform(**cached),
        ComposeTransform([affine_randn(4, 5)]),
        ComposeTransform([affine_randn(4, 5), ExpTransform(**cached)]),
        ComposeTransform([
            AffineTransform(0, 1, **cached),
            affine_randn(4, 5),
            AffineTransform(1, -2, **cached),
            affine_randn(4, 5),
        ]),
    ]
    return forward + [t.inv for t in forward]
Example #5
0
    def __call__(self, name, fn, obs):
        """
        Reparameterize sample site ``name`` by sampling in transformed space.

        An auxiliary site ``"<name>_<self.suffix>"`` is sampled from
        ``dist.TransformedDistribution(fn, transform)``, where ``transform``
        first maps ``fn``'s support to unconstrained space and then applies
        ``self.transform``. The draw is mapped back with ``transform.inv``.

        :param name: name of the original sample site.
        :param fn: the site's distribution.
        :param obs: must be None (checked by assert).
        :return: ``(new_fn, x)`` where ``new_fn`` is a ``dist.Delta`` at ``x``.
        """
        assert obs is None, "TransformReparam does not support observe statements"
        assert fn.event_dim >= self.transform.event_dim, (
            "Cannot transform along batch dimension; "
            "try converting a batch dimension to an event dimension")

        to_unconstrained = _with_cache(biject_to(fn.support).inv)
        transform = ComposeTransform([to_unconstrained, self.transform])
        noise_site = "{}_{}".format(name, self.suffix)
        x_trans = pyro.sample(noise_site, dist.TransformedDistribution(fn, transform))

        # Map back to the original space; expected to be served by the
        # transform cache rather than recomputed.
        x = transform.inv(x_trans)

        # Replace the site with a point mass, like a pyro.deterministic() site.
        return dist.Delta(x, event_dim=fn.event_dim), x
Example #6
0
    def __call__(self, name, fn, obs):
        """
        Reparameterize sample site ``name`` via ``self.transform``.

        An auxiliary site ``"<name>_<self.suffix>"`` is sampled from
        ``dist.TransformedDistribution(fn, transform)``. The draw is mapped
        back with ``transform.inv``, and a ``dist.Delta`` at the recovered
        value is returned.

        When ``self.transform.event_dim`` exceeds ``fn.event_dim``, the
        rightmost ``shift`` batch dims would be transformed. This is only
        permitted when ``self.experimental_allow_batch`` is true. In that case
        ``fn`` is reshaped and the plates at those dims are blocked.

        :param name: name of the original sample site.
        :param fn: the site's distribution.
        :param obs: must be None (checked by assert).
        :return: ``(new_fn, x)`` where ``new_fn`` is a ``dist.Delta`` at ``x``.
        :raises ValueError: if batch dims would be transformed and
            ``experimental_allow_batch`` is false.
        """
        assert obs is None, "TransformReparam does not support observe statements"
        event_dim = fn.event_dim
        transform = self.transform
        with ExitStack() as stack:
            # Number of batch dims the transform would consume beyond fn's event dims.
            shift = max(0, transform.event_dim - event_dim)
            if shift:
                if not self.experimental_allow_batch:
                    raise ValueError(
                        "Cannot transform along batch dimension; try either "
                        "converting a batch dimension to an event dimension, or "
                        "setting experimental_allow_batch=True.")

                # Reshape and mute plates using block_plate.
                from pyro.contrib.forecast.util import (
                    reshape_batch,
                    reshape_transform_batch,
                )
                old_shape = fn.batch_shape
                # Insert `shift` size-1 dims in front of the last `shift` batch
                # dims, then move those last dims into the event shape.
                new_shape = old_shape[:-shift] + (
                    1, ) * shift + old_shape[-shift:]
                fn = reshape_batch(fn, new_shape).to_event(shift)
                # NOTE(review): `fn` was reassigned on the previous line, so
                # `fn.event_shape` below already includes the shifted dims.
                # Confirm these are the intended shapes.
                transform = reshape_transform_batch(transform,
                                                    old_shape + fn.event_shape,
                                                    new_shape + fn.event_shape)
                for dim in range(-shift, 0):
                    stack.enter_context(block_plate(dim=dim, strict=False))

            # Draw noise from the base distribution.
            # NOTE(review): this composes with `self.transform`, so the reshaped
            # local `transform` computed above is discarded. Confirm intent.
            transform = ComposeTransform(
                [biject_to(fn.support).inv.with_cache(), self.transform])
            x_trans = pyro.sample("{}_{}".format(name, self.suffix),
                                  dist.TransformedDistribution(fn, transform))

        # Differentiably transform.
        x = transform.inv(x_trans)  # should be free due to transform cache
        if shift:
            # Remove the `shift` size-1 dims that `new_shape` inserted above.
            x = x.reshape(x.shape[:-2 * shift - event_dim] +
                          x.shape[-shift - event_dim:])

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(x, event_dim=event_dim)
        return new_fn, x
Example #7
0
    def _validate_transform(t):
        """
        Substitute an empty composition when no transform is given.

        Returns ``t`` itself when it is not None. Otherwise it returns
        ``ComposeTransform([])``, a composition with no parts.

        NOTE(review): this is defined without ``self``. Presumably it is a
        ``@staticmethod`` (decorator outside this view) or is only called
        through the class — confirm.

        :param t: The transform to be validated, or None.

        :return: (torch.distributions.Transform) a valid transform.
        """
        if t is not None:
            return t
        return ComposeTransform([])

    def __init__(self, w, p, temperature=0.1, validate_args=None):
        """
        Relaxed dropout mask scaled by ``w``.

        The base distribution is ``RelaxedBernoulli(temperature, p)``. Its
        samples ``b`` are pushed through ``AffineTransform(1, -1)`` and then
        ``AffineTransform(0, w)``. With torch's ``loc + scale * x`` convention
        this yields ``w * (1 - b)``.

        :param w: scale applied after the ``1 - b`` flip.
        :param p: passed to ``RelaxedBernoulli`` positionally after ``temperature``.
        :param temperature: relaxation temperature.
        :param validate_args: forwarded to the parent constructor.
        """
        base = RelaxedBernoulli(temperature, p)
        scale_by_w = AffineTransform(0, w)
        flip = AffineTransform(1, -1)
        chain = ComposeTransform([flip, scale_by_w])
        super(BernoulliDropoutDistribution, self).__init__(base, chain, validate_args)

        # Expose the base distribution and the scaling transform as attributes.
        self.relaxed_bernoulli = base
        self.affine_transform = scale_by_w
Example #9
0
def test_compose_reshape(batch_shape):
    transforms = [ReshapeTransform((), ()),
                  ReshapeTransform((2,), (1, 2)),
                  ReshapeTransform((3, 1, 2), (6,)),
                  ReshapeTransform((6,), (2, 3))]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == 2
    assert transform.domain.event_dim == 2
    data = torch.randn(batch_shape + (3, 2))
    assert transform(data).shape == batch_shape + (2, 3)

    dist = TransformedDistribution(Normal(data, 1), transforms)
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == (2, 3)
    assert dist.support.event_dim == 2
Example #10
0
def test_compose_affine(event_dims):
    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, max(event_dims))
Example #11
0
def reshape_transform(transform, shape):
    """Return ``transform`` with its tensor parameters resized to ``shape``.

    Needed to squash batch dims for testing jacobian.

    * ``AffineTransform``: each tensor ``loc``/``scale`` is expanded to
      ``shape``, falling back to ``reshape`` when expansion is impossible.
      Scalar (``Number``) parameters are kept as they are. ``event_dim`` and
      the cache size are preserved. An affine with only scalar parameters is
      returned unchanged.
    * ``ComposeTransform``: every part is resized recursively.
    * An inverse of either of the above: its forward transform is resized and
      re-inverted.
    * Anything else is returned unchanged.

    :raises RuntimeError: if a tensor parameter can be neither expanded nor
        reshaped to ``shape``.
    """
    if isinstance(transform, AffineTransform):
        loc, scale = transform.loc, transform.scale
        if isinstance(loc, Number) and isinstance(scale, Number):
            return transform

        def resize(param):
            # Scalars broadcast against any shape, so leave them untouched.
            if isinstance(param, Number):
                return param
            try:
                return param.expand(shape)
            except RuntimeError:
                # e.g. squashing (4, 5) -> (20,) cannot be done by expand.
                return param.reshape(shape)

        # Keep event_dim: without it the rebuilt transform would fall back to
        # the default event_dim=0 and reduce its log-det-jacobian differently.
        return AffineTransform(resize(loc), resize(scale),
                               event_dim=transform.event_dim,
                               cache_size=transform._cache_size)
    if isinstance(transform, ComposeTransform):
        reshaped_parts = [reshape_transform(p, shape) for p in transform.parts]
        return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
    inverse = transform.inv
    if isinstance(inverse, (AffineTransform, ComposeTransform)):
        return reshape_transform(inverse, shape).inv
    return transform
Example #12
0
def test_transformed_distribution(base_batch_dim, base_event_dim,
                                  transform_dim, num_transforms, sample_shape):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [
        AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
        ReshapeTransform((4, 5), (20, )),
        ReshapeTransform((3, 20), (6, 10))
    ]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_batch_dim + base_event_dim < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    num_unique = len(set(x.reshape(-1).tolist()))
    assert num_unique >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    y = x
    while y.dim() > len(d.event_shape):
        y = y[0]
    log_prob = d.log_prob(y)
    assert log_prob.shape == d.batch_shape
Example #13
0
    def __init__(self, a, theta, alpha, beta):
        """
        The Amoroso distribution is a very flexible 4 parameter distribution which
        contains many important exponential families as special cases.

        *PDF*
        ```
        Amoroso(x | a, θ, α, β) = 1/gamma(α) * abs(β/θ) * ((x - a)/θ)**(α*β-1) * exp(-((x - a)/θ)**β)
        for:
            x, a, θ, α, β ∈ reals, α > 0
        support:
            x >= a if θ > 0
            x <= a if θ < 0
        ```

        Construction: the base is ``Gamma(α, 1.)``. The transform is the
        inverse of the chain ``x -> (x - a)/θ -> ((x - a)/θ)**β``. Assuming the
        parent applies ``transform`` to base draws (as a
        ``TransformedDistribution`` does), a base draw ``z`` becomes
        ``a + θ * z**(1/β)``.

        :param a: location ``a``.
        :param theta: scale ``θ``; per the support above, its sign picks the
            side of ``a`` on which the support lies.
        :param alpha: Gamma shape ``α`` (``α > 0``).
        :param beta: power ``β``.
        """
        # Broadcast all four parameters against each other.
        self.a, self.theta, self.alpha, self.beta = broadcast_all(
            a, theta, alpha, beta)

        base_dist = Gamma(self.alpha, 1.)
        # AffineTransform(loc, scale) maps x -> loc + scale * x, so the first
        # part is x -> (x - a)/θ. `.inv` reverses the whole chain.
        transform = ComposeTransform([
            AffineTransform(-self.a / self.theta, 1 / self.theta),
            PowerTransform(self.beta),
        ]).inv
        super().__init__(base_dist, transform)
Example #14
0
def _transform_to_positive_ordered_vector(constraint):
    """Compose ``OrderedTransform`` followed by ``ExpTransform``.

    Presumably ``OrderedTransform`` yields an increasing vector. ``exp`` is
    monotone, so it keeps that order while making every entry positive.
    ``constraint`` is not used here; presumably the signature is fixed by a
    constraint-to-transform registry.
    """
    ordered_then_positive = [OrderedTransform(), ExpTransform()]
    return ComposeTransform(ordered_then_positive)
Example #15
0
def _transform_to_corr_matrix(constraint):
    """Compose ``CorrLCholeskyTransform`` with the inverse of ``CorrMatrixCholeskyTransform``.

    The first part presumably produces a correlation Cholesky factor. The
    inverted second part maps that factor to a correlation matrix.
    ``constraint`` is not used here.
    """
    to_cholesky_factor = CorrLCholeskyTransform()
    factor_to_matrix = CorrMatrixCholeskyTransform().inv
    return ComposeTransform([to_cholesky_factor, factor_to_matrix])
Example #16
0
def _transform_to_positive_definite(constraint):
    """Compose ``LowerCholeskyTransform`` with the inverse of ``CholeskyTransform``.

    The first part presumably produces a lower Cholesky factor. The inverted
    second part maps that factor to a positive-definite matrix.
    ``constraint`` is not used here.
    """
    to_lower_cholesky = LowerCholeskyTransform()
    factor_to_matrix = CholeskyTransform().inv
    return ComposeTransform([to_lower_cholesky, factor_to_matrix])
Example #17
0
    def apply(self, msg):
        """
        Reparameterize the sample site described by ``msg`` via ``self.transform``.

        Reads ``msg["name"]``, ``msg["fn"]``, ``msg["value"]`` and
        ``msg["is_observed"]``. An auxiliary site ``f"{name}_{self.suffix}"``
        is sampled from ``dist.TransformedDistribution(fn, transform)``. When
        ``value`` is given, its transformed image is passed as ``obs``. The
        site is then replaced by a ``dist.Delta`` at the value in the original
        space.

        When ``self.transform.event_dim`` exceeds ``fn.event_dim``, the
        rightmost ``shift`` batch dims would be transformed. This is only
        permitted when ``self.experimental_allow_batch`` is true. In that case
        ``fn`` (and ``value``, if given) are reshaped and the plates at those
        dims are blocked.

        :param dict msg: the sample-site message.
        :return: dict with ``"fn"`` (a ``dist.Delta``), ``"value"`` and
            ``"is_observed"`` (always True).
        :raises ValueError: if batch dims would be transformed and
            ``experimental_allow_batch`` is false.
        """
        name = msg["name"]
        fn = msg["fn"]
        value = msg["value"]
        is_observed = msg["is_observed"]

        event_dim = fn.event_dim
        transform = self.transform
        with ExitStack() as stack:
            # Number of batch dims the transform would consume beyond fn's event dims.
            shift = max(0, transform.event_dim - event_dim)
            if shift:
                if not self.experimental_allow_batch:
                    raise ValueError(
                        "Cannot transform along batch dimension; try either "
                        "converting a batch dimension to an event dimension, or "
                        "setting experimental_allow_batch=True.")

                # Reshape and mute plates using block_plate.
                from pyro.contrib.forecast.util import (
                    reshape_batch,
                    reshape_transform_batch,
                )

                old_shape = fn.batch_shape
                # Insert `shift` size-1 dims in front of the last `shift` batch
                # dims, then move those last dims into the event shape.
                new_shape = old_shape[:-shift] + (
                    1, ) * shift + old_shape[-shift:]
                fn = reshape_batch(fn, new_shape).to_event(shift)
                # NOTE(review): `fn` was reassigned on the previous line, so
                # `fn.event_shape` below already includes the shifted dims.
                # Confirm these are the intended shapes.
                transform = reshape_transform_batch(transform,
                                                    old_shape + fn.event_shape,
                                                    new_shape + fn.event_shape)
                if value is not None:
                    # Insert the same size-1 dims into the observed value.
                    value = value.reshape(value.shape[:-shift - event_dim] +
                                          (1, ) * shift +
                                          value.shape[-shift - event_dim:])
                for dim in range(-shift, 0):
                    stack.enter_context(block_plate(dim=dim, strict=False))

            # Differentiably invert transform.
            # NOTE(review): this composes with `self.transform`, so the reshaped
            # local `transform` computed above is discarded. Confirm intent.
            transform = ComposeTransform(
                [biject_to(fn.support).inv.with_cache(), self.transform])
            value_trans = None
            if value is not None:
                value_trans = transform(value)

            # Draw noise from the base distribution.
            value_trans = pyro.sample(
                f"{name}_{self.suffix}",
                dist.TransformedDistribution(fn, transform),
                obs=value_trans,
                infer={"is_observed": is_observed},
            )

        # Differentiably transform. This should be free due to transform cache.
        if value is None:
            value = transform.inv(value_trans)
        if shift:
            # Remove the `shift` size-1 dims inserted above.
            value = value.reshape(value.shape[:-2 * shift - event_dim] +
                                  value.shape[-shift - event_dim:])

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim)
        return {"fn": new_fn, "value": value, "is_observed": True}