示例#1
0
def test_safe_normalize(dim):
    """Check that ``safe_normalize`` yields unit vectors parallel to the input.

    :param dim: size of the trailing (vector) axis under test.
    """
    # Random Gaussian inputs: the result must have unit squared norm, and its
    # dot product with the input, squared, must equal the input's squared norm
    # (i.e. the result points along the input).
    rng_key = random.PRNGKey(0)
    raw = random.normal(rng_key, (100, dim))
    unit = safe_normalize(raw)
    expected_ones = jnp.ones(unit.shape[:-1])
    assert_allclose((unit * unit).sum(-1), expected_ones, rtol=1e-6)
    along = (unit * raw).sum(-1)
    assert_allclose(along**2, (raw * raw).sum(-1), rtol=1e-6)

    # Degenerate all-zero inputs: the result must still have unit squared norm.
    zero_rows = jnp.zeros((10, dim))
    unit = safe_normalize(zero_rows)
    expected_ones = jnp.ones(unit.shape[:-1])
    assert_allclose((unit * unit).sum(-1), expected_ones, rtol=1e-6)
示例#2
0
    def __call__(self, name, fn, obs):
        """Reparameterize a ``ProjectedNormal`` sample site.

        Rather than sampling the site directly, draw parameter-free
        standard-normal noise at an auxiliary site named ``"<name>_normal"``,
        then map it differentiably via ``safe_normalize(noise + concentration)``.

        :param name: name of the original sample site.
        :param fn: the site's distribution; ``self._unwrap(fn)`` must yield a
            ``dist.ProjectedNormal``.
        :param obs: must be ``None``; observed sites are not supported.
        :returns: ``(None, value)``, where ``value`` is the transformed sample.
        """
        assert obs is None, "ProjectedNormalReparam does not support observe statements"
        base, batch_shape, event_dim = self._unwrap(fn)
        assert isinstance(base, dist.ProjectedNormal)
        loc = base.concentration

        # Noise with no learnable parameters, shaped like the concentration.
        noise_prior = dist.Normal(jnp.zeros(loc.shape), 1)
        noise = numpyro.sample(
            "{}_normal".format(name),
            self._wrap(noise_prior, batch_shape, event_dim),
        )

        # A ``None`` distribution paired with a value acts like a
        # pyro.deterministic() site.
        projected = safe_normalize(noise + loc)
        return None, projected
示例#3
0
 def mode(self):
     """Return the mode: ``concentration`` rescaled by ``safe_normalize``."""
     direction = self.concentration
     return safe_normalize(direction)
示例#4
0
 def sample(self, key, sample_shape=()):
     """
     Draw samples by perturbing ``concentration`` and projecting the result.

     Standard-normal noise of shape ``sample_shape + batch_shape + event_shape``
     is added to ``concentration``, and the sum is passed through
     ``safe_normalize``.

     :param key: PRNG key passed to ``random.normal``.
     :param sample_shape: leading sample dimensions; it is concatenated with
         ``batch_shape`` and ``event_shape``, so it is expected to be a tuple.
     :returns: the normalized perturbed values.
     """
     noise = random.normal(
         key,
         shape=sample_shape + self.batch_shape + self.event_shape,
     )
     perturbed = self.concentration + noise
     return safe_normalize(perturbed)
示例#5
0
 def mean(self):
     """
     Return the centroid of the distribution, computed as the normalized
     ``concentration``.

     This is a mean in the sense of the point in the submanifold that
     minimizes the expected squared geodesic distance, not the Euclidean
     average of samples.
     """
     centroid = safe_normalize(self.concentration)
     return centroid