def _biject_to_simplex(constraint):
    return transforms.StickBreakingTransform()


@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    loc = constraint.upper_bound
    scale = loc.new([-1]).expand_as(loc)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(), transforms.AffineTransform(loc, scale)])


biject_to.register(constraints.unit_interval, transforms.SigmoidTransform())
transform_to.register(constraints.unit_interval, transforms.SigmoidTransform())


@biject_to.register(constraints.interval)
@transform_to.register(constraints.interval)
def _transform_to_interval(constraint):
    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform([
        transforms.SigmoidTransform(),
        transforms.AffineTransform(loc, scale)
    ])


biject_to.register(constraints.simplex, transforms.StickBreakingTransform())
transform_to.register(constraints.simplex, transforms.BoltzmannTransform())


# TODO define a bijection for LowerCholeskyTransform
transform_to.register(constraints.lower_cholesky, transforms.LowerCholeskyTransform())
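# Illustrative usage sketch (not part of the registry above; the helper name
# below is hypothetical). transform_to(constraint) looks up a registered
# factory and returns a Transform that maps unconstrained reals into the
# constrained space; biject_to additionally guarantees a true bijection.
def _example_constraint_registry_usage():
    import torch  # assumed available alongside torch.distributions

    t = transform_to(constraints.interval(0.0, 1.0))  # Sigmoid composed with Affine
    x = torch.randn(5)
    y = t(x)
    assert ((0 < y) & (y < 1)).all()  # outputs lie in the open interval (0, 1)
    assert torch.allclose(t.inv(y), x, atol=1e-5)  # round-trips through the inverse

    b = biject_to(constraints.simplex)  # StickBreakingTransform
    z = b(torch.randn(4))  # a length-5 vector on the probability simplex
    assert torch.allclose(z.sum(), torch.tensor(1.0))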
@pytest.mark.parametrize("size", [1, 2, 3])
def test_kl_independent_normal_mvn(batch_shape, size):
    loc = torch.randn(batch_shape + (size,))
    scale = torch.randn(batch_shape + (size,)).exp()
    p1 = dist.Normal(loc, scale).to_event(1)
    p2 = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())

    loc = torch.randn(batch_shape + (size,))
    cov = torch.randn(batch_shape + (size, size))
    cov = cov @ cov.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(loc, covariance_matrix=cov)

    actual = kl_divergence(p1, q)
    expected = kl_divergence(p2, q)
    assert_close(actual, expected)


@pytest.mark.parametrize("shape", [(5,), (4, 5), (2, 3, 5)], ids=str)
@pytest.mark.parametrize("event_dim", [0, 1])
@pytest.mark.parametrize(
    "transform", [transforms.ExpTransform(), transforms.StickBreakingTransform()]
)
def test_kl_transformed_transformed(shape, event_dim, transform):
    p_base = dist.Normal(torch.zeros(shape), torch.ones(shape)).to_event(event_dim)
    q_base = dist.Normal(torch.ones(shape) * 2, torch.ones(shape)).to_event(event_dim)
    p = dist.TransformedDistribution(p_base, transform)
    q = dist.TransformedDistribution(q_base, transform)
    kl = kl_divergence(q, p)
    expected_shape = shape[:-1] if max(transform.event_dim, event_dim) == 1 else shape
    assert kl.shape == expected_shape
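# Hypothetical companion check, sketched here rather than taken from the suite
# above: KL divergence is invariant under a bijective transform applied to both
# distributions, so the transformed/transformed KL should agree with the KL of
# the base distributions. Assumes the same names as the tests above (torch,
# dist, transforms, kl_divergence, assert_close) and the TransformedDistribution
# KL rule that test_kl_transformed_transformed exercises.
def test_kl_transformed_matches_base_kl_sketch():
    base_p = dist.Normal(torch.zeros(5), torch.ones(5))
    base_q = dist.Normal(torch.ones(5) * 2, torch.ones(5))
    t = transforms.ExpTransform()
    p = dist.TransformedDistribution(base_p, t)  # log-normal with base loc 0
    q = dist.TransformedDistribution(base_q, t)  # log-normal with base loc 2
    assert_close(kl_divergence(p, q), kl_divergence(base_p, base_q))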