예제 #1
0
 def _get_transform(self):
     """Build the affine transform from learnable location and Cholesky-factor params.

     Two ``numpyro.param`` sites are used, named ``"<prefix>_loc"`` and
     ``"<prefix>_scale_tril"``. The scale site is constrained to be a lower
     Cholesky factor and is initialised to ``self._init_scale`` times the
     ``self.latent_size`` identity matrix.
     """
     loc_site = '{}_loc'.format(self.prefix)
     loc = numpyro.param(loc_site, self._init_latent)

     scale_tril_site = '{}_scale_tril'.format(self.prefix)
     initial_tril = np.identity(self.latent_size) * self._init_scale
     scale_tril = numpyro.param(scale_tril_site, initial_tril,
                                constraint=constraints.lower_cholesky)

     return MultivariateAffineTransform(loc, scale_tril)
예제 #2
0
 def _get_transform(self, params):
     """Return the affine transform for a low-rank-plus-diagonal Gaussian.

     Reads ``"<prefix>_loc"``, ``"<prefix>_cov_factor"`` (``W``) and
     ``"<prefix>_scale"`` (``s``) from ``params``. Each row of ``W`` is
     rescaled by ``s``, and ``s**2`` is used as the diagonal term. The
     resulting covariance is therefore ``diag(s) @ (W @ W.T + I) @ diag(s)``.
     Its lower Cholesky factor, as computed by ``LowRankMultivariateNormal``,
     becomes the transform's scale.
     """
     prefix = self.prefix
     loc = params['{}_loc'.format(prefix)]
     raw_factor = params['{}_cov_factor'.format(prefix)]
     scale = params['{}_scale'.format(prefix)]

     # Broadcast ``scale`` over the rank dimension to rescale each row of W.
     low_rank_mvn = dist.LowRankMultivariateNormal(loc,
                                                   raw_factor * scale[..., None],
                                                   scale * scale)
     return MultivariateAffineTransform(loc, low_rank_mvn.scale_tril)
예제 #3
0
    def get_transform(self, params):
        """Return the Laplace-approximation transform centred at the MAP location.

        The precision matrix is the Hessian of ``self._loss_fn`` with respect to
        ``params["<prefix>_loc"]``, with all other entries of ``params`` held
        fixed. ``cholesky_of_inverse`` turns it into a lower-triangular scale for
        a ``MultivariateAffineTransform`` centred at ``loc``.

        If the Cholesky factor contains NaNs (a singular Hessian), those entries
        are zeroed, so the transform collapses onto ``loc``. A warning is
        emitted when the values are concrete, i.e. not a JAX tracer.

        :param dict params: current guide parameters, containing the
            ``"<prefix>_loc"`` entry.
        :return: the affine transform ``loc + scale_tril @ x``.
        """
        loc_key = '{}_loc'.format(self.prefix)

        def loss_fn(z):
            # Evaluate the loss with only the latent location replaced by ``z``;
            # copy so the caller's ``params`` dict is not mutated.
            params1 = params.copy()
            params1[loc_key] = z
            return self._loss_fn(params1)

        loc = params[loc_key]
        precision = hessian(loss_fn)(loc)
        scale_tril = cholesky_of_inverse(precision)
        nan_mask = jnp.isnan(scale_tril)
        # Values can only be inspected eagerly; under jit/grad they are abstract
        # tracers, so the warning is skipped there.
        if not_jax_tracer(scale_tril) and jnp.any(nan_mask):
            warnings.warn(
                "Hessian of log posterior at the MAP point is singular. Posterior"
                " samples from AutoLaplaceApproximation will be constant (equal to"
                " the MAP point).")
        scale_tril = jnp.where(nan_mask, 0., scale_tril)
        return MultivariateAffineTransform(loc, scale_tril)
예제 #4
0
    def _get_transform(self, params):
        """Return the Laplace-approximation transform centred at the MAP location.

        The precision matrix is the Hessian, with respect to
        ``params["<prefix>_loc"]``, of ``AutoContinuousELBO().loss`` evaluated
        for ``self.model`` with this guide. ``cholesky_of_inverse`` turns it
        into the lower-triangular scale of a ``MultivariateAffineTransform``
        centred at ``loc``.

        If the Cholesky factor contains NaNs (a singular Hessian), those entries
        are zeroed, so the transform collapses onto ``loc``. A warning is
        emitted when the values are concrete, i.e. not a JAX tracer.

        :param dict params: current guide parameters, containing the
            ``"<prefix>_loc"`` entry.
        :return: the affine transform ``loc + scale_tril @ x``.
        """
        loc_key = '{}_loc'.format(self.prefix)

        def loss_fn(z):
            # Copy so the caller's ``params`` dict is not mutated.
            params1 = params.copy()
            params1[loc_key] = z
            # we are doing maximum likelihood, so only require `num_particles=1` and an arbitrary rng_key.
            return AutoContinuousELBO().loss(random.PRNGKey(0), params1, self.model, self,
                                             *self._args, **self._kwargs)

        loc = params[loc_key]
        precision = hessian(loss_fn)(loc)
        scale_tril = cholesky_of_inverse(precision)
        nan_mask = np.isnan(scale_tril)
        # Values can only be inspected eagerly; under jit/grad they are abstract
        # tracers, so the warning is skipped there.
        if not_jax_tracer(scale_tril) and np.any(nan_mask):
            warnings.warn("Hessian of log posterior at the MAP point is singular. Posterior"
                          " samples from AutoLaplaceApproximation will be constant (equal to"
                          " the MAP point).")
        scale_tril = np.where(nan_mask, 0., scale_tril)
        return MultivariateAffineTransform(loc, scale_tril)
예제 #5
0
            expected = onp.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6, rtol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6, rtol=1e-6)


# NB: skip transforms which are tested in `test_biject_to`
@pytest.mark.parametrize('transform, event_shape', [
    (PermuteTransform(np.array([3, 0, 4, 1, 2])), (5,)),
    (PowerTransform(2.), ()),
    (MultivariateAffineTransform(np.array([1., 2.]), np.array([[0.6, 0.], [1.5, 0.4]])), (2,))
])
@pytest.mark.parametrize('batch_shape', [(), (1,), (3,), (6,), (3, 1), (1, 3), (5, 3)])
def test_bijective_transforms(transform, event_shape, batch_shape):
    """Check that ``transform`` lands in its codomain and that ``inv`` undoes it.

    Standard-normal draws are pushed into ``transform.domain`` through
    ``biject_to``, so every parametrized transform receives valid input.
    """
    sample_shape = batch_shape + event_shape
    to_domain = biject_to(transform.domain)
    x = to_domain(random.normal(random.PRNGKey(0), sample_shape))
    y = transform(x)

    # the codomain check must hold for every batch element
    assert_array_equal(transform.codomain(y), np.ones(batch_shape))

    # applying the inverse must recover the original input
    x_recovered = transform.inv(y)
    assert_allclose(x, x_recovered, atol=1e-6, rtol=1e-6)
예제 #6
0
 def get_transform(self, params):
     """Return an affine transform that reproduces the guide's posterior.

     The shift and lower-triangular scale are taken from the ``loc`` and
     ``scale_tril`` attributes of ``self.get_posterior(params)``.
     """
     gaussian = self.get_posterior(params)
     shift, tril = gaussian.loc, gaussian.scale_tril
     return MultivariateAffineTransform(shift, tril)
예제 #7
0
 def get_transform(self, params):
     """Build the affine transform directly from stored guide parameters.

     Uses ``params["<prefix>_loc"]`` as the shift and
     ``params["<prefix>_scale_tril"]`` as the lower-triangular scale.
     """
     prefix = self.prefix
     shift = params['{}_loc'.format(prefix)]
     tril = params['{}_scale_tril'.format(prefix)]
     return MultivariateAffineTransform(shift, tril)