def __init__(self, base_distribution, transforms, validate_args=None):
    if isinstance(transforms, Transform):
        self.transforms = [transforms, ]
    elif isinstance(transforms, list):
        if not all(isinstance(t, Transform) for t in transforms):
            raise ValueError("transforms must be a Transform or a list of Transforms")
        self.transforms = transforms
    else:
        raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))

    # Reshape base_distribution according to transforms.
    base_shape = base_distribution.batch_shape + base_distribution.event_shape
    base_event_dim = len(base_distribution.event_shape)
    transform = ComposeTransform(self.transforms)
    domain_event_dim = transform.domain.event_dim
    if len(base_shape) < domain_event_dim:
        raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
                         .format(domain_event_dim, base_shape))
    shape = transform.forward_shape(base_shape)
    expanded_base_shape = transform.inverse_shape(shape)
    if base_shape != expanded_base_shape:
        base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim]
        base_distribution = base_distribution.expand(base_batch_shape)
    reinterpreted_batch_ndims = domain_event_dim - base_event_dim
    if reinterpreted_batch_ndims > 0:
        base_distribution = Independent(base_distribution, reinterpreted_batch_ndims)
    self.base_dist = base_distribution

    # Compute shapes.
    event_dim = transform.codomain.event_dim + max(base_event_dim - domain_event_dim, 0)
    assert len(shape) >= event_dim
    cut = len(shape) - event_dim
    batch_shape = shape[:cut]
    event_shape = shape[cut:]
    super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)
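# Illustrative usage sketch (not part of the library code above), showing how the
# __init__ above splits the transformed shape into batch_shape and event_shape.
# It assumes only the public torch.distributions API and a PyTorch version that
# provides Transform.forward_shape (as the code above already requires).
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform, StickBreakingTransform

# A scalar (event_dim == 0) transform leaves the base shape split unchanged.
d = TransformedDistribution(Normal(torch.zeros(3, 4), 1.), ExpTransform())
assert d.batch_shape == (3, 4) and d.event_shape == ()

# StickBreakingTransform has domain.event_dim == 1, so the rightmost base dim is
# absorbed into the event shape (the base is wrapped in Independent internally)
# and forward_shape grows it from 4 to 5.
d = TransformedDistribution(Normal(torch.zeros(3, 4), 1.), StickBreakingTransform())
assert d.batch_shape == (3,) and d.event_shape == (5,)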
def get_transforms(cache_size):
    transforms = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.).normal_(), cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.).normal_(), cache_size=cache_size),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size),
        AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        SoftmaxTransform(cache_size=cache_size),
        SoftplusTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
            ExpTransform(cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(0, 1, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
            AffineTransform(1, -2, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        ]),
        ReshapeTransform((4, 5), (2, 5, 2)),
        IndependentTransform(
            AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size), 1),
        CumulativeDistributionTransform(Normal(0, 1)),
    ]
    transforms += [t.inv for t in transforms]
    return transforms
def test_compose_transform_shapes():
    transform0 = ExpTransform()
    transform1 = SoftmaxTransform()
    transform2 = LowerCholeskyTransform()

    assert transform0.event_dim == 0
    assert transform1.event_dim == 1
    assert transform2.event_dim == 2
    assert ComposeTransform([transform0, transform1]).event_dim == 1
    assert ComposeTransform([transform0, transform2]).event_dim == 2
    assert ComposeTransform([transform1, transform2]).event_dim == 2
def get_transforms(cache_size):
    transforms = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.).normal_(), cache_size=cache_size),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size),
        AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        SoftmaxTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
            ExpTransform(cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(0, 1, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
            AffineTransform(1, -2, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        ]),
    ]
    transforms += [t.inv for t in transforms]
    return transforms
def __call__(self, name, fn, obs):
    assert obs is None, "TransformReparam does not support observe statements"
    assert fn.event_dim >= self.transform.event_dim, (
        "Cannot transform along batch dimension; "
        "try converting a batch dimension to an event dimension")

    # Draw noise from the base distribution.
    transform = ComposeTransform(
        [_with_cache(biject_to(fn.support).inv), self.transform])
    x_trans = pyro.sample("{}_{}".format(name, self.suffix),
                          dist.TransformedDistribution(fn, transform))

    # Differentiably transform.
    x = transform.inv(x_trans)  # should be free due to transform cache

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(x, event_dim=fn.event_dim)
    return new_fn, x
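# Illustrative usage sketch for the TransformReparam-style reparameterizer above
# (not part of the original source). The model, site name "x", and the check on
# the auxiliary site name are illustrative assumptions; the reparameterizer
# replaces the transformed site with a base-noise sample plus a Delta site.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import TransformReparam

def model():
    pyro.sample("x", dist.TransformedDistribution(
        dist.Normal(torch.zeros(3), 1.).to_event(1),
        dist.transforms.ExpTransform()))

reparam_model = poutine.reparam(model, config={"x": TransformReparam()})
trace = poutine.trace(reparam_model).get_trace()
assert "x" in trace.nodes                                        # deterministic Delta site
assert any(n.startswith("x_") for n in trace.nodes if n != "x")  # auxiliary base-noise site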
def __call__(self, name, fn, obs):
    assert obs is None, "TransformReparam does not support observe statements"
    event_dim = fn.event_dim
    transform = self.transform
    with ExitStack() as stack:
        shift = max(0, transform.event_dim - event_dim)
        if shift:
            if not self.experimental_allow_batch:
                raise ValueError(
                    "Cannot transform along batch dimension; try either "
                    "converting a batch dimension to an event dimension, or "
                    "setting experimental_allow_batch=True.")

            # Reshape and mute plates using block_plate.
            from pyro.contrib.forecast.util import (
                reshape_batch,
                reshape_transform_batch,
            )

            old_shape = fn.batch_shape
            new_shape = old_shape[:-shift] + (1,) * shift + old_shape[-shift:]
            fn = reshape_batch(fn, new_shape).to_event(shift)
            transform = reshape_transform_batch(transform,
                                                old_shape + fn.event_shape,
                                                new_shape + fn.event_shape)
            for dim in range(-shift, 0):
                stack.enter_context(block_plate(dim=dim, strict=False))

        # Draw noise from the base distribution.
        transform = ComposeTransform(
            [biject_to(fn.support).inv.with_cache(), self.transform])
        x_trans = pyro.sample("{}_{}".format(name, self.suffix),
                              dist.TransformedDistribution(fn, transform))

        # Differentiably transform.
        x = transform.inv(x_trans)  # should be free due to transform cache
        if shift:
            x = x.reshape(x.shape[:-2 * shift - event_dim] +
                          x.shape[-shift - event_dim:])

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(x, event_dim=event_dim)
    return new_fn, x
def _validate_transform(t):
    """
    Ensure that the provided transform can be evaluated.

    :param t: The transform to be validated
    :return: (torch.distributions.Transform) a valid transform.
    """
    return ComposeTransform([]) if t is None else t
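# Illustrative usage sketch for _validate_transform above (not part of the
# original source): a missing transform becomes the identity, expressed as an
# empty ComposeTransform, while a provided transform is passed through unchanged.
import torch
from torch.distributions.transforms import ComposeTransform, ExpTransform

identity = _validate_transform(None)
assert isinstance(identity, ComposeTransform)
assert torch.equal(identity(torch.ones(3)), torch.ones(3))  # acts as the identity
assert isinstance(_validate_transform(ExpTransform()), ExpTransform)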
def __init__(self, w, p, temperature=0.1, validate_args=None):
    relaxed_bernoulli = RelaxedBernoulli(temperature, p)
    affine_transform = AffineTransform(0, w)
    one_minus_p = AffineTransform(1, -1)
    super(BernoulliDropoutDistribution, self).__init__(
        relaxed_bernoulli,
        ComposeTransform([one_minus_p, affine_transform]),
        validate_args)
    self.relaxed_bernoulli = relaxed_bernoulli
    self.affine_transform = affine_transform
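# Illustrative sketch of the composed transform used above (not part of the
# original source). ComposeTransform applies its parts left to right, so a
# relaxed Bernoulli draw x is mapped to w * (1 - x): the weight is (softly)
# dropped when x is near 1. The tensors below are arbitrary example values.
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform

w = torch.tensor([2.0, -3.0])
t = ComposeTransform([AffineTransform(1, -1), AffineTransform(0, w)])
x = torch.tensor([0.0, 1.0])
assert torch.allclose(t(x), torch.tensor([2.0, 0.0]))  # elementwise w * (1 - x)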
def test_compose_reshape(batch_shape):
    transforms = [ReshapeTransform((), ()),
                  ReshapeTransform((2,), (1, 2)),
                  ReshapeTransform((3, 1, 2), (6,)),
                  ReshapeTransform((6,), (2, 3))]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == 2
    assert transform.domain.event_dim == 2
    data = torch.randn(batch_shape + (3, 2))
    assert transform(data).shape == batch_shape + (2, 3)

    dist = TransformedDistribution(Normal(data, 1), transforms)
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == (2, 3)
    assert dist.support.event_dim == 2
def test_compose_affine(event_dims):
    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, max(event_dims))
def reshape_transform(transform, shape):
    # Needed to squash batch dims for testing jacobian
    if isinstance(transform, AffineTransform):
        if isinstance(transform.loc, Number):
            return transform
        try:
            return AffineTransform(transform.loc.expand(shape),
                                   transform.scale.expand(shape),
                                   cache_size=transform._cache_size)
        except RuntimeError:
            return AffineTransform(transform.loc.reshape(shape),
                                   transform.scale.reshape(shape),
                                   cache_size=transform._cache_size)
    if isinstance(transform, ComposeTransform):
        reshaped_parts = []
        for p in transform.parts:
            reshaped_parts.append(reshape_transform(p, shape))
        return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
    if isinstance(transform.inv, AffineTransform):
        return reshape_transform(transform.inv, shape).inv
    if isinstance(transform.inv, ComposeTransform):
        return reshape_transform(transform.inv, shape).inv
    return transform
def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
                                  num_transforms, sample_shape):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
                  ReshapeTransform((4, 5), (20,)),
                  ReshapeTransform((3, 20), (6, 10))]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_batch_dim + base_event_dim < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    num_unique = len(set(x.reshape(-1).tolist()))
    assert num_unique >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    y = x
    while y.dim() > len(d.event_shape):
        y = y[0]
    log_prob = d.log_prob(y)
    assert log_prob.shape == d.batch_shape
def __init__(self, a, theta, alpha, beta):
    """
    The Amoroso distribution is a very flexible 4-parameter distribution which
    contains many important exponential families as special cases.

    *PDF*
    ```
    Amoroso(x | a, θ, α, β) =
        1/gamma(α) * abs(β/θ) * ((x - a)/θ)**(α*β - 1) * exp(-((x - a)/θ)**β)

    for:
        x, a, θ, α, β ∈ reals, α > 0
    support:
        x >= a if θ > 0
        x <= a if θ < 0
    ```
    """
    self.a, self.theta, self.alpha, self.beta = broadcast_all(a, theta, alpha, beta)
    base_dist = Gamma(self.alpha, 1.)
    transform = ComposeTransform([
        AffineTransform(-self.a / self.theta, 1 / self.theta),
        PowerTransform(self.beta),
    ]).inv
    super().__init__(base_dist, transform)
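# Illustrative usage sketch for the constructor above (not part of the original
# source). The enclosing class is assumed to be named Amoroso and to subclass
# TransformedDistribution, which the snippet itself does not show. With a = 0,
# θ = 1 and β = 1 the transform is the identity, so the distribution reduces to
# Gamma(α, 1) and its support is x >= a = 0.
import torch

d = Amoroso(a=0., theta=1., alpha=2., beta=1.)  # reduces to Gamma(2, 1)
x = d.rsample((1000,))
assert bool((x >= 0).all())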
def _transform_to_positive_ordered_vector(constraint):
    return ComposeTransform([OrderedTransform(), ExpTransform()])
def _transform_to_corr_matrix(constraint):
    return ComposeTransform(
        [CorrLCholeskyTransform(), CorrMatrixCholeskyTransform().inv]
    )
def _transform_to_positive_definite(constraint):
    return ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv])
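# Illustrative sketch of what the positive-definite registration above computes
# (not part of the original source): an unconstrained square matrix is mapped to
# a lower-Cholesky factor L with positive diagonal, and CholeskyTransform().inv
# then rebuilds L @ L.T. CholeskyTransform is assumed here to come from
# pyro.distributions.transforms; the snippets above do not show their imports or
# registration decorators.
import torch
from torch.distributions.transforms import ComposeTransform, LowerCholeskyTransform
from pyro.distributions.transforms import CholeskyTransform

t = ComposeTransform([LowerCholeskyTransform(), CholeskyTransform().inv])
y = t(torch.randn(4, 4))
assert bool((torch.linalg.eigvalsh(y) > 0).all())  # y is positive definite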
def apply(self, msg):
    name = msg["name"]
    fn = msg["fn"]
    value = msg["value"]
    is_observed = msg["is_observed"]

    event_dim = fn.event_dim
    transform = self.transform
    with ExitStack() as stack:
        shift = max(0, transform.event_dim - event_dim)
        if shift:
            if not self.experimental_allow_batch:
                raise ValueError(
                    "Cannot transform along batch dimension; try either "
                    "converting a batch dimension to an event dimension, or "
                    "setting experimental_allow_batch=True.")

            # Reshape and mute plates using block_plate.
            from pyro.contrib.forecast.util import (
                reshape_batch,
                reshape_transform_batch,
            )

            old_shape = fn.batch_shape
            new_shape = old_shape[:-shift] + (1,) * shift + old_shape[-shift:]
            fn = reshape_batch(fn, new_shape).to_event(shift)
            transform = reshape_transform_batch(transform,
                                                old_shape + fn.event_shape,
                                                new_shape + fn.event_shape)
            if value is not None:
                value = value.reshape(value.shape[:-shift - event_dim] +
                                      (1,) * shift +
                                      value.shape[-shift - event_dim:])
            for dim in range(-shift, 0):
                stack.enter_context(block_plate(dim=dim, strict=False))

        # Differentiably invert transform.
        transform = ComposeTransform(
            [biject_to(fn.support).inv.with_cache(), self.transform])
        value_trans = None
        if value is not None:
            value_trans = transform(value)

        # Draw noise from the base distribution.
        value_trans = pyro.sample(
            f"{name}_{self.suffix}",
            dist.TransformedDistribution(fn, transform),
            obs=value_trans,
            infer={"is_observed": is_observed},
        )

        # Differentiably transform. This should be free due to transform cache.
        if value is None:
            value = transform.inv(value_trans)
        if shift:
            value = value.reshape(value.shape[:-2 * shift - event_dim] +
                                  value.shape[-shift - event_dim:])

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(value, event_dim=event_dim)
    return {"fn": new_fn, "value": value, "is_observed": True}