示例#1
0
    def apply(self, msg):
        """Replace a ``ProjectedNormal`` site with an auxiliary standard-normal site.

        A ``"{name}_normal"`` site is drawn from a zero-mean, unit-scale
        ``Normal`` shaped like ``fn.concentration`` (one event dimension).
        The original site then receives ``safe_normalize(x + concentration)``,
        wrapped in a ``Delta`` with ``mask(False)`` to simulate a
        ``pyro.deterministic()`` site.

        If the message already carries a value, ``value - concentration`` is
        passed as ``obs`` to the auxiliary site, and the original value is
        returned unchanged. That inverse is an arbitrary injection that is only
        meaningful for initialization.

        :param dict msg: Site message; reads ``"name"``, ``"fn"``, ``"value"``
            and ``"is_observed"``.
        :returns: A dict with the replacement ``"fn"``, the site ``"value"``
            and ``"is_observed": True``.
        :raises NotImplementedError: If the site is observed.
        """
        name, fn, value, is_observed = (
            msg["name"],
            msg["fn"],
            msg["value"],
            msg["is_observed"],
        )
        if is_observed:
            raise NotImplementedError(
                "ProjectedNormalReparam does not support observe statements"
            )

        base, event_dim = self._unwrap(fn)
        assert isinstance(base, dist.ProjectedNormal)
        loc = base.concentration

        # Pseudo-inverse of the projection; only valid as an initialization hint.
        value_normal = None if value is None else value - loc

        # Parameter-free auxiliary noise site.
        noise_dist = dist.Normal(torch.zeros_like(loc), 1).to_event(1)
        x = pyro.sample(
            "{}_normal".format(name),
            self._wrap(noise_dist, event_dim),
            obs=value_normal,
            infer={"is_observed": is_observed},
        )

        # Project shifted noise onto the sphere unless a value was provided.
        value = safe_normalize(x + loc) if value is None else value

        # Masked Delta mimics a pyro.deterministic() site.
        return {
            "fn": dist.Delta(value, event_dim=event_dim).mask(False),
            "value": value,
            "is_observed": True,
        }
示例#2
0
def test_sphere_check(dim):
    """Check ``constraints.sphere`` on raw and normalized Gaussian vectors.

    Raw Gaussian vectors must be rejected; their ``safe_normalize``-d
    versions must be accepted, with one result per vector.
    """
    sphere = constraints.sphere
    gaussian = torch.randn(100, dim)
    assert not sphere.check(gaussian).any()

    on_sphere = safe_normalize(gaussian)
    verdict = sphere.check(on_sphere)
    assert verdict.all()
    assert verdict.shape == on_sphere.shape[:-1]
示例#3
0
    def __call__(self, name, fn, obs):
        """Reparametrize a ``ProjectedNormal`` site via an auxiliary standard-normal site.

        Samples ``"{name}_normal"`` from a zero-mean, unit-scale ``Normal``
        shaped like ``fn.concentration`` (one event dimension). It then returns
        a ``Delta`` at ``safe_normalize(x + fn.concentration)`` masked with
        ``mask(False)``, simulating a ``pyro.deterministic()`` site.

        :param name: Name of the site being reparametrized.
        :param fn: Site distribution; after ``self._unwrap`` it must be a
            ``ProjectedNormal``.
        :param obs: Must be ``None``; observed sites are not supported.
        :returns: A ``(new_fn, value)`` pair.
        :raises NotImplementedError: If ``obs`` is not ``None``.
        """
        fn, event_dim = self._unwrap(fn)
        assert isinstance(fn, dist.ProjectedNormal)
        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently drop the observation.
        if obs is not None:
            raise NotImplementedError(
                "ProjectedNormalReparam does not support observe statements"
            )

        # Draw parameter-free noise.
        new_fn = dist.Normal(torch.zeros_like(fn.concentration), 1).to_event(1)
        x = pyro.sample("{}_normal".format(name),
                        self._wrap(new_fn, event_dim))

        # Differentiably transform.
        value = safe_normalize(x + fn.concentration)

        # Simulate a pyro.deterministic() site.
        new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
        return new_fn, value
示例#4
0
 def _call(self, x):
     """Forward transform: normalize ``x`` via ``safe_normalize`` with norm order ``self.p``."""
     norm_order = self.p
     return safe_normalize(x, p=norm_order)
示例#5
0
 def rsample(self, sample_shape=torch.Size()):
     """Draw reparameterized samples.

     Standard-normal noise of shape ``self._extended_shape(sample_shape)`` is
     shifted by ``self.concentration`` and projected with ``safe_normalize``.

     :param torch.Size sample_shape: Leading sample dimensions.
     :returns: The normalized, shifted noise.
     """
     shape = self._extended_shape(sample_shape)
     loc = self.concentration
     noise = loc.new_empty(shape).normal_()
     return safe_normalize(noise + loc)
示例#6
0
 def mode(self):
     """Mode: the direction of ``self.concentration``, normalized with ``safe_normalize``."""
     direction = self.concentration
     return safe_normalize(direction)
示例#7
0
 def mean(self):
     """
     Mean direction, computed as ``safe_normalize(self.concentration)``.

     Note this is the mean in the sense of a centroid in the submanifold
     that minimizes expected squared geodesic distance.
     """
     centroid = safe_normalize(self.concentration)
     return centroid