def random_normal(self, *size, mean=0, std=1, dtype=None, device=None):
    """
    Create a point on the manifold, measure is induced by Normal distribution on the tangent space of zero.

    Parameters
    ----------
    size : shape
        the desired shape
    mean : float|tensor
        mean value for the Normal distribution
    std : float|tensor
        std value for the Normal distribution
    dtype : torch.dtype
        target dtype for sample, if not None, should match Manifold dtype
    device : torch.device
        target device for sample, if not None, should match Manifold device

    Returns
    -------
    ManifoldTensor
        random point on the PoincareBall manifold

    Raises
    ------
    ValueError
        if ``device`` or ``dtype`` is given and differs from ``self.c.device`` / ``self.c.dtype``

    Notes
    -----
    The device and dtype will match the device and dtype of the Manifold
    """
    self._assert_check_shape(size2shape(*size), "x")
    # dtype/device are accepted for API parity with the other samplers, but the
    # sample is always created on the manifold's own device/dtype, so a mismatch
    # is reported instead of being silently ignored.
    if device is not None and device != self.c.device:
        raise ValueError(
            "`device` does not match the manifold `device`, set the `device` argument to None"
        )
    if dtype is not None and dtype != self.c.dtype:
        raise ValueError(
            "`dtype` does not match the manifold `dtype`, set the `dtype` argument to None"
        )
    # Sample in the tangent space at zero, then map the sample onto the manifold.
    tens = torch.randn(*size, device=self.c.device, dtype=self.c.dtype) * std + mean
    return ManifoldTensor(self.expmap0(tens), manifold=self)
def random_normal(self, *size, mean=0.0, std=1.0, device=None, dtype=None):
    """
    Create a point on the manifold, measure is induced by Normal distribution.

    Parameters
    ----------
    size : shape
        the desired shape
    mean : float|tensor
        mean value for the Normal distribution
    std : float|tensor
        std value for the Normal distribution
    device : torch.device
        the desired device
    dtype : torch.dtype
        the desired dtype

    Returns
    -------
    ManifoldTensor
        random point on the manifold

    Notes
    -----
    ``mean`` and ``std`` are converted with ``torch.as_tensor`` using ``device``
    and ``dtype``; the raw sample is allocated with ``std.new_empty``, so when
    ``device``/``dtype`` are None it follows ``std``'s device and dtype. If
    ``dtype`` is None and ``std`` is not a floating/complex value (e.g. ``std=1``),
    the default floating dtype is used instead.
    """
    self._assert_check_shape(size2shape(*size), "x")
    mean = torch.as_tensor(mean, device=device, dtype=dtype)
    std = torch.as_tensor(std, device=device, dtype=dtype)
    if dtype is None and not (std.is_floating_point() or std.is_complex()):
        # An integer std (e.g. ``std=1``) would give an integer tensor, and
        # ``normal_`` cannot fill integer tensors; promote to the default float dtype.
        std = std.to(torch.get_default_dtype())
    tens = std.new_empty(*size).normal_() * std + mean
    return tensor.ManifoldTensor(tens, manifold=self)
def random_normal(
    self, *size, mean=0, std=1, dtype=None, device=None
) -> "geoopt.ManifoldTensor":
    """
    Create a point on the manifold, measure is induced by Normal distribution on the tangent space of zero.

    Parameters
    ----------
    size : shape
        the desired shape
    mean : float|tensor
        shift added to the tangent-space sample (after the ``std`` scaling
        described below), before mapping onto the manifold
    std : float|tensor
        scale of the Normal distribution. Each coordinate is drawn with standard
        deviation ``std / sqrt(size[-1])``, so the expected squared norm of the
        tangent sample (before adding ``mean``) is ``std ** 2`` regardless of
        the last dimension
    dtype: torch.dtype
        target dtype for sample, if not None, should match Manifold dtype
    device: torch.device
        target device for sample, if not None, should match Manifold device

    Returns
    -------
    ManifoldTensor
        random point on the PoincareBall manifold

    Raises
    ------
    ValueError
        if ``device`` or ``dtype`` is given and differs from ``self.c.device`` / ``self.c.dtype``

    Notes
    -----
    The device and dtype will match the device and dtype of the Manifold
    """
    size = size2shape(*size)
    self._assert_check_shape(size, "x")
    if device is not None and device != self.c.device:
        raise ValueError(
            "`device` does not match the manifold `device`, set the `device` argument to None"
        )
    if dtype is not None and dtype != self.c.dtype:
        raise ValueError(
            "`dtype` does not match the manifold `dtype`, set the `dtype` argument to None"
        )
    # Divide by sqrt(last dim) so the tangent vector's scale does not grow with dimension.
    tens = (
        torch.randn(size, device=self.c.device, dtype=self.c.dtype)
        * std
        / size[-1] ** 0.5
        + mean
    )
    return geoopt.ManifoldTensor(self.expmap0(tens), manifold=self)
def random_uniform(self, *size, dtype=None, device=None):
    """
    Uniform random measure on Sphere manifold.

    Parameters
    ----------
    size : shape
        the desired output shape
    dtype : torch.dtype
        desired dtype
    device : torch.device
        desired device

    Returns
    -------
    ManifoldTensor
        random point on Sphere manifold

    Raises
    ------
    ValueError
        if a projector is set and ``device``/``dtype`` is given but differs
        from the projector's device/dtype

    Notes
    -----
    In case of projector on the manifold, dtype and device are set automatically and shouldn't be provided.
    If you provide them, they are checked to match the projector device and dtype
    """
    self._assert_check_shape(size2shape(*size), "x")
    if self.projector is None:
        tens = torch.randn(*size, device=device, dtype=dtype)
    else:
        if device is not None and device != self.projector.device:
            raise ValueError(
                "`device` does not match the projector `device`, set the `device` argument to None"
            )
        if dtype is not None and dtype != self.projector.dtype:
            raise ValueError(
                "`dtype` does not match the projector `dtype`, set the `dtype` argument to None"
            )
        tens = torch.randn(
            *size, device=self.projector.device, dtype=self.projector.dtype
        )
    # Project the Gaussian sample onto the manifold with ``self.projx``.
    return ManifoldTensor(self.projx(tens), manifold=self)
def random_naive(self, *size, dtype=None, device=None):
    """
    Sample a Stiefel point by orthonormalizing a Gaussian matrix (naive method).

    A Gaussian matrix of the requested shape is drawn and the first factor of
    its QR decomposition is returned. This is cheap to compute, but the induced
    measure on the Stiefel manifold is not uniform.

    Parameters
    ----------
    size : shape
        the desired output shape
    dtype : torch.dtype
        desired dtype
    device : torch.device
        desired device

    Returns
    -------
    ManifoldTensor
        random point on Stiefel manifold
    """
    requested_shape = size2shape(*size)
    self._assert_check_shape(requested_shape, "x")
    gaussian = torch.randn(*size, device=device, dtype=dtype)
    # Keep only the orthonormal factor of the QR decomposition.
    orthonormal = linalg.qr(gaussian)[0]
    return ManifoldTensor(orthonormal, manifold=self)