def apply(self, msg):
    """
    Reparameterize a ``ProjectedNormal`` sample site via an auxiliary Normal site.

    An auxiliary ``"<name>_normal"`` site draws standard-normal noise that does
    not depend on the distribution's parameters; the original site is replaced
    by a masked ``Delta`` at the (normalized) value.

    :param dict msg: Site message; reads the ``"name"``, ``"fn"``, ``"value"``
        and ``"is_observed"`` entries.
    :returns: A dict with keys ``"fn"`` (a masked ``Delta``), ``"value"`` and
        ``"is_observed"`` (always ``True``).
    :raises NotImplementedError: if the site is observed.
    """
    site_name, site_fn = msg["name"], msg["fn"]
    value, observed = msg["value"], msg["is_observed"]
    if observed:
        raise NotImplementedError(
            "ProjectedNormalReparam does not support observe statements"
        )

    base_fn, event_dim = self._unwrap(site_fn)
    assert isinstance(base_fn, dist.ProjectedNormal)
    loc = base_fn.concentration

    # Differentiably invert the transform. This is an arbitrary injection
    # (any preimage of ``value`` would do), so it is only suitable for
    # initialization.
    aux_obs = None if value is None else value - loc

    # Auxiliary site carrying parameter-free noise.
    noise_fn = dist.Normal(torch.zeros_like(loc), 1).to_event(1)
    x = pyro.sample(
        "{}_normal".format(site_name),
        self._wrap(noise_fn, event_dim),
        obs=aux_obs,
        infer={"is_observed": observed},
    )

    # Differentiably map the noise onto the sphere unless a value was supplied.
    if value is None:
        value = safe_normalize(x + loc)

    # Simulate a pyro.deterministic() site with a Delta masked by mask(False).
    return {
        "fn": dist.Delta(value, event_dim=event_dim).mask(False),
        "value": value,
        "is_observed": True,
    }
def test_sphere_check(dim):
    """Check that ``constraints.sphere`` rejects raw Gaussian vectors but
    accepts them after ``safe_normalize``, returning one flag per vector."""
    raw = torch.randn(100, dim)
    # No unnormalized Gaussian sample should pass the sphere check.
    assert not constraints.sphere.check(raw).any()

    unit = safe_normalize(raw)
    flags = constraints.sphere.check(unit)
    assert flags.all()
    # The check reduces over the last (event) dimension only.
    assert flags.shape == unit.shape[:-1]
def __call__(self, name, fn, obs):
    """
    Reparameterize a ``ProjectedNormal`` site (legacy ``__call__`` interface).

    :param str name: Name of the site; the auxiliary noise site is named
        ``"<name>_normal"``.
    :param fn: The site's distribution; must unwrap to ``dist.ProjectedNormal``.
    :param obs: Must be ``None``; observed sites are not supported.
    :returns: A ``(new_fn, value)`` pair, where ``new_fn`` is a ``Delta`` at
        ``value`` masked with ``mask(False)``.
    :raises NotImplementedError: if ``obs`` is not ``None``.
    """
    # Raise explicitly instead of using ``assert``: asserts are stripped under
    # ``python -O``, which would silently ignore ``obs``. This also matches the
    # exception type the ``apply`` hook uses for the same condition.
    if obs is not None:
        raise NotImplementedError(
            "ProjectedNormalReparam does not support observe statements"
        )

    fn, event_dim = self._unwrap(fn)
    assert isinstance(fn, dist.ProjectedNormal)

    # Draw parameter-free noise.
    new_fn = dist.Normal(torch.zeros_like(fn.concentration), 1).to_event(1)
    x = pyro.sample("{}_normal".format(name), self._wrap(new_fn, event_dim))

    # Differentiably transform.
    value = safe_normalize(x + fn.concentration)

    # Simulate a pyro.deterministic() site.
    new_fn = dist.Delta(value, event_dim=event_dim).mask(False)
    return new_fn, value
def _call(self, x):
    """Normalize ``x`` with ``safe_normalize`` using this transform's norm
    order ``self.p``."""
    order = self.p
    return safe_normalize(x, p=order)
def rsample(self, sample_shape=torch.Size()):
    """
    Draw a reparameterized sample.

    Standard-normal noise of shape ``self._extended_shape(sample_shape)`` is
    shifted by ``self.concentration`` and then passed through
    ``safe_normalize``, keeping the sample differentiable in ``concentration``.
    """
    noise = self.concentration.new_empty(self._extended_shape(sample_shape))
    noise.normal_()
    return safe_normalize(noise + self.concentration)
def mode(self):
    """Return the mode, computed as the normalized concentration vector."""
    direction = safe_normalize(self.concentration)
    return direction
def mean(self):
    """
    Return the normalized concentration vector as the mean.

    This is not the ambient (Euclidean) mean. It is the centroid on the
    submanifold, i.e. the point minimizing expected squared geodesic distance.
    """
    centroid = safe_normalize(self.concentration)
    return centroid